Code Example #1
 def apply(self,
           inputs,
           mlp_dim,
           dtype=jnp.float32,
           out_dim=None,
           dropout_rate=0.1,
           deterministic=True,
           kernel_init=nn.initializers.xavier_uniform(),
           bias_init=nn.initializers.normal(stddev=1e-6)):
   """Applies Transformer MlpBlock module."""
   actual_out_dim = inputs.shape[-1] if out_dim is None else out_dim
   x = nn.Dense(
       inputs,
       mlp_dim,
       dtype=dtype,
       kernel_init=kernel_init,
       bias_init=bias_init)
   x = nn.gelu(x)
   x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
   output = nn.Dense(
       x,
       actual_out_dim,
       dtype=dtype,
       kernel_init=kernel_init,
       bias_init=bias_init)
   output = nn.dropout(output, rate=dropout_rate, deterministic=deterministic)
   return output
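
All of the snippets on this page target the legacy, pre-Linen `flax.nn` interface, in which a module is a class whose `apply` method builds the computation and sublayers such as `nn.Dense` and `nn.dropout` are called as functions. The following is a minimal sketch of how a block like the one above would be constructed and run under that API, assuming the `apply` above is the body of a class `MlpBlock(nn.Module)`; the shapes, hyperparameters, and import path are illustrative, and this API has since been removed from current Flax releases.

import jax
import jax.numpy as jnp
from flax import nn  # legacy pre-Linen module API (later flax.deprecated.nn)

# Bind the hyperparameters, leaving only the inputs free.
module = MlpBlock.partial(mlp_dim=2048, dropout_rate=0.1, deterministic=False)

rng = jax.random.PRNGKey(0)
with nn.stochastic(rng):  # supplies the RNG stream that nn.dropout draws from
    # Shape-based lazy initialization returns (output, parameter pytree).
    _, params = module.init_by_shape(rng, [((1, 16, 512), jnp.float32)])
    model = nn.Model(module, params)
    out = model(jnp.ones((8, 16, 512)))  # out_dim defaults to the input feature size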
Code Example #2
 def apply(self,
           inputs,
           mlp_dim,
           dtype=jnp.float32,
           out_dim=None,
           dropout_rate=0.1,
           deterministic=False,
           kernel_init=nn.initializers.xavier_uniform(),
           bias_init=nn.initializers.normal(stddev=1e-6),
           num_partitions=2):
     """Applies Transformer MlpBlock module."""
     actual_out_dim = inputs.shape[-1] if out_dim is None else out_dim
     inputs_shape = inputs.shape
     inputs = inputs.reshape((-1, inputs_shape[-1]))
     x = nn.Dense(inputs,
                  mlp_dim,
                  dtype=dtype,
                  kernel_init=kernel_init,
                  bias_init=bias_init)
     x = nn.relu(x)
     if num_partitions > 1:
         x = with_sharding_constraint(x, P(1, num_partitions))
     x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
     output = nn.Dense(x,
                       actual_out_dim,
                       dtype=dtype,
                       kernel_init=kernel_init,
                       bias_init=bias_init)
     output = nn.dropout(output,
                         rate=dropout_rate,
                         deterministic=deterministic)
     output = output.reshape(inputs_shape[:-1] + (actual_out_dim, ))
     return output
Code Example #3
    def apply(self,
              embed: jnp.ndarray,
              lengths: jnp.ndarray,
              hidden_size: int = None,
              output_size: int = None,
              dropout: float = None,
              emb_dropout: float = None,
              train: bool = None):
        """Encodes the input sequence and makes a prediction using an MLP."""
        # embed <float32>[batch_size, seq_length, embedding_size]
        # lengths <int64>[batch_size]
        if train:
            embed = nn.dropout(embed, rate=emb_dropout)

        # Encode the sequence of embedding using an LSTM.
        hidden = LSTM(embed, lengths, hidden_size=hidden_size, name='lstm')
        if train:
            hidden = nn.dropout(hidden, rate=dropout)

        # Predict the class using an MLP.
        logits = MLP(hidden,
                     hidden_size=hidden_size,
                     output_size=output_size,
                     output_bias=False,
                     dropout=dropout,
                     name='mlp',
                     train=train)
        return logits
Code Example #4
File: adabelief_vgg.py  Project: cshallue/init2winit
def classifier(x, num_outputs, dropout_rate, deterministic):
    """Implements the classification portion of the network."""

    x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
    x = nn.Dense(x, 512)
    x = nn.relu(x)
    x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
    x = nn.Dense(x, 512)
    x = nn.relu(x)
    x = nn.Dense(x, num_outputs)
    return x
Code Example #5
  def apply(self,
            inputs,
            vocab_size,
            output_vocab_size,
            emb_dim=512,
            num_heads=8,
            num_layers=6,
            qkv_dim=512,
            mlp_dim=2048,
            max_len=2048,
            train=True,
            dropout_rate=0.3,
            attention_dropout_rate=0.3):
    """Applies Transformer model on the inputs.

    Args:
      inputs: input data
      vocab_size: size of the input vocabulary
      output_vocab_size: size of the output classes
      emb_dim: dimension of embedding
      num_heads: number of heads
      num_layers: number of layers
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      max_len: maximum length.
      train: whether the model is in training mode.
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights

    Returns:
      output of a transformer decoder.

    """
    padding_mask = jnp.where(inputs > 0, 1, 0).astype(jnp.float32)[..., None]
    assert inputs.ndim == 2  # (batch, len)

    x = inputs.astype('int32')
    x = Embed(x, num_embeddings=vocab_size, features=emb_dim, name='embed')
    x = nn.dropout(x, rate=dropout_rate, deterministic=not train)
    x = AddPositionEmbs(
        x, max_len=max_len, posemb_init=sinusoidal_init(max_len=max_len))
    for _ in range(num_layers):
      x = Transformer1DBlock(
          x,
          qkv_dim=qkv_dim,
          mlp_dim=mlp_dim,
          num_heads=num_heads,
          causal_mask=False,
          padding_mask=padding_mask,
          dropout_rate=dropout_rate,
          attention_dropout_rate=attention_dropout_rate,
          deterministic=not train,
      )
    x = nn.LayerNorm(x)
    logits = nn.Dense(
        x,
        output_vocab_size,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    return logits
Code Example #6
    def apply(self,
              x,
              act,
              normalize,
              temb=None,
              out_ch=None,
              conv_shortcut=False,
              dropout=0.1,
              train=True,
              skip_rescale=False,
              init_scale=0.):
        B, H, W, C = x.shape
        out_ch = out_ch if out_ch else C
        h = act(normalize(x, num_groups=min(x.shape[-1] // 4, 32)))
        h = conv3x3(h, out_ch)
        # Add bias to each feature map conditioned on the time embedding
        if temb is not None:
            h += nn.Dense(act(temb), out_ch,
                          kernel_init=default_init())[:, None, None, :]

        h = act(normalize(h, num_groups=min(h.shape[-1] // 4, 32)))
        h = nn.dropout(h, dropout, deterministic=not train)
        h = conv3x3(h, out_ch, init_scale=init_scale)
        if C != out_ch:
            if conv_shortcut:
                x = conv3x3(x, out_ch)
            else:
                x = NIN(x, out_ch)

        if not skip_rescale:
            return x + h
        else:
            return (x + h) / np.sqrt(2.)
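
The `NIN` layer used for the channel-mismatch shortcut here (and again in Code Example #8) is not included in the snippet. It is conventionally a learned 1x1 channel mixing ("network in network"); the sketch below is an assumption about its shape under the legacy API rather than the repository's exact code, and it substitutes a plain variance-scaling initializer for the project's `default_init` helper.

import jax
import jax.numpy as jnp
from flax import nn  # legacy pre-Linen API


class NIN(nn.Module):
    """Network-in-network: a per-pixel linear map over the channel axis."""

    def apply(self, x, num_units, init_scale=1.0):
        kernel_init = jax.nn.initializers.variance_scaling(
            init_scale, 'fan_avg', 'uniform')
        W = self.param('W', (x.shape[-1], num_units), kernel_init)
        b = self.param('b', (num_units,), nn.initializers.zeros)
        # Equivalent to a 1x1 convolution: mixes channels at each spatial location.
        return jnp.einsum('bhwc,cd->bhwd', x, W) + b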
Code Example #7
 def apply(
         self,
         hidden_states,
         *,
         d_ff: int,
         dropout_rate: float = 0.0,
         intermediate_activation=nn.gelu,
         # TODO(kitaev): chunk_size hparam for chunking
         kernel_init=nn.initializers.xavier_uniform(),
         deterministic: bool = False):
     """Applies FeedForward module."""
     d_model = hidden_states.shape[-1]
     hidden_states = nn.Dense(hidden_states,
                              d_ff,
                              kernel_init=kernel_init,
                              name='intermediate')
     hidden_states = intermediate_activation(hidden_states)
     hidden_states = nn.Dense(hidden_states,
                              d_model,
                              kernel_init=kernel_init,
                              name='output')
     hidden_states = nn.dropout(hidden_states,
                                rate=dropout_rate,
                                deterministic=deterministic)
     return hidden_states
Code Example #8
 def apply(self,
           x,
           act,
           normalize,
           temb=None,
           out_ch=None,
           conv_shortcut=False,
           dropout=0.5,
           train=True):
     B, H, W, C = x.shape
     out_ch = out_ch if out_ch else C
     h = act(normalize(x))
     h = ddpm_conv3x3(h, out_ch)
     # Add bias to each feature map conditioned on the time embedding
     if temb is not None:
         h += nn.Dense(act(temb), out_ch,
                       kernel_init=default_init())[:, None, None, :]
     h = act(normalize(h))
     h = nn.dropout(h, dropout, deterministic=not train)
     h = ddpm_conv3x3(h, out_ch, init_scale=0.)
     if C != out_ch:
         if conv_shortcut:
             x = ddpm_conv3x3(x, out_ch)
         else:
             x = NIN(x, out_ch)
     return x + h
Code Example #9
    def apply(self,
              inputs,
              qkv_dim,
              mlp_dim,
              num_heads,
              causal_mask=False,
              padding_mask=None,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              deterministic=False,
              attention_fn=nn.dot_product_attention,
              cache=None):
        """Applies Transformer1DBlock module.

    Args:
      inputs: input data
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      num_heads: number of heads
      causal_mask: bool, mask future or not
      padding_mask: bool, mask padding tokens
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      deterministic: bool, deterministic or not (to apply dropout)
      attention_fn: dot product function to use inside attention.
      cache: Cache for decoding.

    Returns:
      output after transformer block.

    """

        # Attention block.
        assert inputs.ndim == 3
        x = nn.LayerNorm(inputs)
        x = nn.SelfAttention(x,
                             num_heads=num_heads,
                             qkv_features=qkv_dim,
                             attention_axis=(1, ),
                             causal_mask=causal_mask,
                             padding_mask=padding_mask,
                             kernel_init=nn.initializers.xavier_uniform(),
                             bias_init=nn.initializers.normal(stddev=1e-6),
                             bias=False,
                             broadcast_dropout=False,
                             dropout_rate=attention_dropout_rate,
                             deterministic=deterministic,
                             attention_fn=attention_fn,
                             cache=cache)
        x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
        x = x + inputs

        # MLP block.
        y = nn.LayerNorm(x)
        y = MlpBlock(y,
                     mlp_dim=mlp_dim,
                     dropout_rate=dropout_rate,
                     deterministic=deterministic)

        return x + y
Code Example #10
File: models.py  Project: zhang-yd15/flax
    def apply(self,
              inputs,
              qkv_dim,
              mlp_dim,
              num_heads,
              dtype=jnp.float32,
              inputs_segmentation=None,
              padding_mask=None,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              deterministic=False):
        """Applies Encoder1DBlock module.

    Args:
      inputs: input data.
      qkv_dim: dimension of the query/key/value.
      mlp_dim: dimension of the mlp on top of attention block.
      num_heads: number of heads.
      dtype: the dtype of the computation (default: float32).
      inputs_segmentation: input segmentation info for packed examples.
      padding_mask: bool, mask padding tokens.
      dropout_rate: dropout rate.
      attention_dropout_rate: dropout rate for attention weights.
      deterministic: bool, deterministic or not (to apply dropout).

    Returns:
      output after transformer encoder block.
    """

        # Attention block.
        assert inputs.ndim == 3
        x = nn.LayerNorm(inputs, dtype=dtype)
        x = nn.SelfAttention(x,
                             num_heads=num_heads,
                             dtype=dtype,
                             inputs_kv=x,
                             qkv_features=qkv_dim,
                             attention_axis=(1, ),
                             causal_mask=False,
                             segmentation=inputs_segmentation,
                             padding_mask=padding_mask,
                             kernel_init=nn.initializers.xavier_uniform(),
                             bias_init=nn.initializers.normal(stddev=1e-6),
                             bias=False,
                             broadcast_dropout=False,
                             dropout_rate=attention_dropout_rate,
                             deterministic=deterministic)
        x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
        x = x + inputs

        # MLP block.
        y = nn.LayerNorm(x, dtype=dtype)
        y = MlpBlock(y,
                     mlp_dim=mlp_dim,
                     dtype=dtype,
                     dropout_rate=dropout_rate,
                     deterministic=deterministic)

        return x + y
Code Example #11
  def apply(self,
            inputs,
            mlp_dim,
            inputs_masks=None,
            dtype=jnp.float32,
            dropout_rate=0.1,
            attention_dropout_rate=0.1,
            deterministic=True,
            layer_drop_p=None,
            **attention_kwargs):
    """Applies Encoder1DBlock module.

    Args:
      inputs: input data.
      mlp_dim: dimension of the mlp on top of attention block.
      inputs_masks: bool, input mask.
      dtype: the dtype of the computation (default: float32).
      dropout_rate: dropout rate.
      attention_dropout_rate: dropout for attention heads.
      deterministic: bool, deterministic or not (to apply dropout).
      layer_drop_p: probability of dropping a layer.
      **attention_kwargs: kwargs passed to nn.SelfAttention

    Returns:
      output after transformer encoder block.
    """

    # Attention block.
    assert inputs.ndim == 3
    x = nn.LayerNorm(inputs, dtype=dtype)
    x = nn.SelfAttention(
        x,
        dtype=dtype,
        inputs_kv=x,
        attention_axis=(1,),
        causal_mask=False,
        padding_mask=inputs_masks,
        kernel_init=nn.initializers.xavier_uniform(),
        broadcast_dropout=False,
        deterministic=deterministic,
        dropout_rate=attention_dropout_rate,
        **attention_kwargs)
    x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)

    drop_pattern = self.get_drop_pattern(x, layer_drop_p)
    x = x * (1.0 - drop_pattern) + inputs

    # MLP block.
    y = nn.LayerNorm(x, dtype=dtype)
    y = MlpBlock(
        y,
        mlp_dim=mlp_dim,
        dtype=dtype,
        dropout_rate=dropout_rate,
        deterministic=deterministic)

    drop_pattern = self.get_drop_pattern(x, layer_drop_p)
    return y * (1.0 - drop_pattern) + x
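
The stochastic-depth helper `self.get_drop_pattern` referenced above is not part of the snippet. A minimal sketch of what such a helper could compute under the legacy API is shown below; this is an assumption rather than the repository's actual implementation, and a complete version would also return zeros when running deterministically.

import jax
from flax import nn  # legacy pre-Linen API, for nn.make_rng inside nn.stochastic


def get_drop_pattern(self, x, layer_drop_p):
    """Per-example Bernoulli mask; 1.0 means 'skip this block for that example'."""
    if not layer_drop_p:  # None or 0.0: never drop
        return 0.0
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # broadcast over non-batch axes
    return jax.random.bernoulli(nn.make_rng(), layer_drop_p, shape).astype(x.dtype)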
Code Example #12
    def apply(self,
              x,
              channels,
              strides=(1, 1),
              dropout_rate=0.0,
              normalization='bn',
              activation_f=None,
              std_penalty_mult=0,
              use_residual=1,
              train=True,
              bias_scale=0.0,
              weight_norm='none',
              compensate_padding=True):
        norm = get_norm(activation_f, normalization, train)

        conv = get_conv(activation_f, bias_scale, weight_norm,
                        compensate_padding, normalization)
        penalty = 0
        y = x
        y = norm(y, name='norm1')
        if std_penalty_mult > 0:
            penalty += std_penalty(y)
        y = activation_f(y, features=y.shape[-1])
        y = conv(
            y,
            channels,
            (3, 3),
            strides,
            padding='SAME',
            name='conv1',
        )
        y = norm(y, name='norm2')
        if std_penalty_mult > 0:
            penalty += std_penalty(y)
        y = activation_f(y, features=y.shape[-1])
        if dropout_rate > 0.0:
            y = nn.dropout(y, dropout_rate, deterministic=not train)
        y = conv(y, channels, (3, 3), padding='SAME', name='conv2')

        if use_residual == 1:
            # Apply an up projection in case of channel mismatch
            if (x.shape[-1] != channels) or strides != (1, 1):
                x = conv(x, y.shape[-1], (3, 3), strides, padding='SAME')
            result = x + y
        elif use_residual == 2:
            # Unit variance preserving residual.
            if (x.shape[-1] != channels) or strides != (1, 1):
                x = conv(x, y.shape[-1], (3, 3), strides, padding='SAME')

            result = (x + y) / jnp.sqrt(
                1**2 + 1**2)  # Sum of independent normals.
        else:
            result = y

        return result, penalty
Code Example #13
 def apply(self,
           hidden_states,
           mask=None,
           *,
           d_qkv=64,
           attention_dropout_rate=0.0,
           output_dropout_rate=0.0,
           deterministic=False,
           kernel_init=nn.linear.default_kernel_init,
           output_kernel_init=nn.initializers.xavier_uniform(),
           bias_init=nn.initializers.zeros,
           bias=True):
     """Applies attention for a single batch element and head."""
     d_model = hidden_states.shape[-1]
     dense = nn.DenseGeneral.partial(axis=-1,
                                     features=(d_qkv, ),
                                     kernel_init=kernel_init,
                                     bias_init=bias_init,
                                     bias=bias)
     query, key, value = (dense(hidden_states, name='query'),
                          dense(hidden_states, name='key'),
                          dense(hidden_states, name='value'))
     attention_scores = jnp.einsum('TN,FN->FT', key, query)
     attention_scores = attention_scores / jnp.sqrt(d_qkv)
     if mask is not None:
         padding_mask = (1.0 - mask[None, :]) * NEG_INFINITY
         attention_scores = attention_scores + padding_mask
     attention_scores = nn.softmax(attention_scores)
     attention_probs = nn.dropout(attention_scores,
                                  rate=attention_dropout_rate,
                                  deterministic=deterministic)
     hidden_states = jnp.einsum('FT,TH->FH', attention_probs, value)
     hidden_states = nn.linear.DenseGeneral(hidden_states,
                                            features=d_model,
                                            axis=(-1, ),
                                            kernel_init=output_kernel_init,
                                            name='output')
     hidden_states = nn.dropout(hidden_states,
                                rate=output_dropout_rate,
                                deterministic=deterministic)
     return hidden_states
Code Example #14
    def apply(self, g, x, in_feats, hidden_feats, out_feats, num_layers, dropout):
        with nn.stochastic(jax.random.PRNGKey(0)):
            x = SAGEConv(g, x, in_feats, hidden_feats)

            for idx in range(num_layers-2):
                x = SAGEConv(g, x, hidden_feats, hidden_feats)
                x = nn.BatchNorm(x)
                x = nn.dropout(x, rate=dropout)

            x = SAGEConv(g, x, hidden_feats, out_feats)

        return jax.nn.log_softmax(x, axis=-1)
Code Example #15
 def apply(self,
           inputs: jnp.ndarray,
           hidden_size: int = None,
           output_size: int = None,
           output_bias: bool = False,
           dropout: float = None,
           train: bool = None):
     # inputs.shape = <float32>[batch_size, seq_length, hidden_size]
     hidden = nn.Dense(inputs, hidden_size, name='hidden')
     hidden = nn.tanh(hidden)
     if train:
         hidden = nn.dropout(hidden, rate=dropout)
     output = nn.Dense(hidden, output_size, bias=output_bias, name='output')
     return output
Code Example #16
    def apply(self,
              x,
              act,
              normalize,
              up=False,
              down=False,
              temb=None,
              out_ch=None,
              dropout=0.1,
              fir=False,
              fir_kernel=[1, 3, 3, 1],
              train=True,
              skip_rescale=True,
              init_scale=0.):
        B, H, W, C = x.shape
        out_ch = out_ch if out_ch else C
        h = act(normalize(x, num_groups=min(x.shape[-1] // 4, 32)))

        if up:
            if fir:
                h = up_or_down_sampling.upsample_2d(h, fir_kernel, factor=2)
                x = up_or_down_sampling.upsample_2d(x, fir_kernel, factor=2)
            else:
                h = up_or_down_sampling.naive_upsample_2d(h, factor=2)
                x = up_or_down_sampling.naive_upsample_2d(x, factor=2)
        elif down:
            if fir:
                h = up_or_down_sampling.downsample_2d(h, fir_kernel, factor=2)
                x = up_or_down_sampling.downsample_2d(x, fir_kernel, factor=2)
            else:
                h = up_or_down_sampling.naive_downsample_2d(h, factor=2)
                x = up_or_down_sampling.naive_downsample_2d(x, factor=2)

        h = conv3x3(h, out_ch)
        # Add bias to each feature map conditioned on the time embedding
        if temb is not None:
            h += nn.Dense(act(temb), out_ch,
                          kernel_init=default_init())[:, None, None, :]

        h = act(normalize(h, num_groups=min(h.shape[-1] // 4, 32)))
        h = nn.dropout(h, dropout, deterministic=not train)
        h = conv3x3(h, out_ch, init_scale=init_scale)
        if C != out_ch or up or down:
            x = conv1x1(x, out_ch)

        if not skip_rescale:
            return x + h
        else:
            return (x + h) / np.sqrt(2.)
Code Example #17
    def apply(self,
              inputs,
              mlp_dim,
              dtype=jnp.float32,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              deterministic=True,
              **attention_kwargs):
        """Applies Encoder1DBlock module.

    Args:
      inputs: input data.
      mlp_dim: dimension of the mlp on top of attention block.
      dtype: the dtype of the computation (default: float32).
      dropout_rate: dropout rate.
      attention_dropout_rate: dropout for attention heads.
      deterministic: bool, deterministic or not (to apply dropout).
      **attention_kwargs: kwargs passed to nn.SelfAttention

    Returns:
      output after transformer encoder block.
    """

        # Attention block.
        assert inputs.ndim == 3
        x = nn.LayerNorm(inputs, dtype=dtype)
        x = modified_attention.SelfAttention_modified(
            x,
            dtype=dtype,
            inputs_kv=x,
            attention_axis=(1, ),
            causal_mask=False,
            kernel_init=nn.initializers.xavier_uniform(),
            broadcast_dropout=False,
            deterministic=deterministic,
            dropout_rate=attention_dropout_rate,
            **attention_kwargs)
        x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
        x = x + inputs

        # MLP block.
        y = nn.LayerNorm(x, dtype=dtype)
        y = MlpBlock(y,
                     mlp_dim=mlp_dim,
                     dtype=dtype,
                     dropout_rate=dropout_rate,
                     deterministic=deterministic)

        return x + y
Code Example #18
File: models.py  Project: dandelin/vision_transformer
    def apply(
        self,
        inputs,
        num_layers,
        mlp_dim,
        inputs_positions=None,
        dropout_rate=0.1,
        train=False,
        **attention_kwargs,
    ):
        """Applies Transformer model on the inputs.

    Args:
      inputs: input data
      num_layers: number of layers
      mlp_dim: dimension of the mlp on top of attention block
      inputs_positions: input subsequence positions for packed examples.
      dropout_rate: dropout rate
      train: whether the model is in training mode.
      **attention_kwargs: kwargs passed to nn.SelfAttention

    Returns:
      output of a transformer encoder.
    """
        assert inputs.ndim == 3  # (batch, len, emb)

        x = AddPositionEmbs(
            inputs,
            inputs_positions=inputs_positions,
            posemb_init=nn.initializers.normal(stddev=0.02),  # from BERT.
            name="posembed_input",
        )
        x = nn.dropout(x, rate=dropout_rate, deterministic=not train)

        # Input Encoder
        for lyr in range(num_layers):
            x = Encoder1DBlock(
                x,
                mlp_dim=mlp_dim,
                dropout_rate=dropout_rate,
                deterministic=not train,
                name=f"encoderblock_{lyr}",
                **attention_kwargs,
            )
        encoded = nn.LayerNorm(x, name="encoder_norm")

        return encoded
Code Example #19
  def apply(self, x, channels, strides=(1, 1), dropout_rate=0.0, train=True):
    batch_norm = nn.BatchNorm.partial(use_running_average=not train,
                                      momentum=0.9, epsilon=1e-5)

    y = batch_norm(x, name='bn1')
    y = jax.nn.relu(y)
    y = nn.Conv(y, channels, (3, 3), strides, padding='SAME', name='conv1')
    y = batch_norm(y, name='bn2')
    y = jax.nn.relu(y)
    if dropout_rate > 0.0:
      y = nn.dropout(y, dropout_rate, deterministic=not train)
    y = nn.Conv(y, channels, (3, 3), padding='SAME', name='conv2')

    # Apply an up projection in case of channel mismatch
    if (x.shape[-1] != channels) or strides != (1, 1):
      x = nn.Conv(x, channels, (3, 3), strides, padding='SAME')
    return x + y
Code Example #20
def GatedResnet(inputs,
                aux=None,
                conv_module=None,
                nonlinearity=concat_elu,
                dropout_p=0.):
    c = inputs.shape[-1]
    y = conv_module(nonlinearity(inputs), c)
    if aux is not None:
        y = nonlinearity(y + ConvOneByOne(nonlinearity(aux), c))

    if dropout_p > 0:
        y = nn.dropout(y, dropout_p)

    # Set init_scale=0.1 so that the res block is close to the identity at
    # initialization.
    a, b = np.split(conv_module(y, 2 * c, init_scale=0.1), 2, axis=-1)
    return inputs + a * nn.sigmoid(b)
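
`concat_elu` is referenced but not defined in this snippet. In PixelCNN++-style code it is conventionally the concatenated ELU nonlinearity; the short sketch below spells that out (this is the usual definition, assumed here rather than taken from the project).

import jax
import jax.numpy as jnp


def concat_elu(x, axis=-1):
    """Concatenated ELU: ELU applied to [x, -x], doubling the channel count."""
    return jax.nn.elu(jnp.concatenate([x, -x], axis=axis))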
Code Example #21
File: modeling.py  Project: nikitakit/flax_bert
    def apply(self,
              input_ids,
              input_mask,
              type_ids,
              labels=None,
              *,
              config,
              n_classes,
              deterministic=False):
        """Applies BERT for sequence classification."""
        unused_sequence_output, pooled_output = BertModel(
            input_ids,
            input_mask,
            type_ids,
            config=config,
            deterministic=deterministic,
            name='bert')
        pooled_output = nn.dropout(pooled_output,
                                   rate=config.hidden_dropout_prob,
                                   deterministic=deterministic)
        logits = layers.OutputProjection(pooled_output,
                                         n_out=n_classes,
                                         kernel_init=get_kernel_init(config),
                                         name='classification')

        if labels is None:
            return logits
        elif logits.shape[-1] == 1:
            # Regression task
            loss = jnp.mean((logits[..., 0] - labels)**2)
            return {'loss': loss}
        else:
            # Classification task
            logits = nn.log_softmax(logits)
            loss = -jnp.mean(
                jnp.sum(onehot(labels, logits.shape[-1]) * logits, axis=-1))
            return {'loss': loss}
Code Example #22
File: modeling.py  Project: nikitakit/flax_bert
    def apply(self,
              input_ids,
              input_mask,
              type_ids,
              *,
              config,
              deterministic=False):
        """Applies BERT model on the inputs."""

        word_embeddings = nn.Embed(input_ids,
                                   num_embeddings=config.vocab_size,
                                   features=config.hidden_size,
                                   embedding_init=get_kernel_init(config),
                                   name='word_embeddings')
        position_embeddings = layers.PositionalEncoding(
            word_embeddings,
            max_len=config.max_position_embeddings,
            posemb_init=get_kernel_init(config),
            name='position_embeddings')
        type_embeddings = nn.Embed(type_ids,
                                   num_embeddings=config.type_vocab_size,
                                   features=config.hidden_size,
                                   embedding_init=get_kernel_init(config),
                                   name='type_embeddings')

        embeddings = word_embeddings + position_embeddings + type_embeddings
        embeddings = nn.LayerNorm(embeddings,
                                  epsilon=LAYER_NORM_EPSILON,
                                  name='embeddings_layer_norm')
        embeddings = nn.dropout(embeddings,
                                rate=config.hidden_dropout_prob,
                                deterministic=deterministic)

        # Transformer blocks
        feed_forward = layers.FeedForward.partial(
            d_ff=config.intermediate_size,
            dropout_rate=config.hidden_dropout_prob,
            intermediate_activation=get_hidden_activation(config),
            kernel_init=get_kernel_init(config))

        attention = efficient_attention.BertSelfAttention.partial(
            num_heads=config.num_attention_heads,
            num_parallel_heads=None,
            d_qkv=config.hidden_size // config.num_attention_heads,
            attention_dropout_rate=config.attention_probs_dropout_prob,
            output_dropout_rate=config.hidden_dropout_prob,
            kernel_init=get_kernel_init(config),
            output_kernel_init=get_kernel_init(config))

        hidden_states = embeddings
        mask = input_mask.astype(jnp.int32)
        for layer_num in range(config.num_hidden_layers):
            hidden_states = layers.TransformerBlock(
                hidden_states,
                mask,
                feed_forward=feed_forward,
                attention=attention,
                deterministic=deterministic,
                name=f'encoder_layer_{layer_num}')

        pooled_output = nn.Dense(hidden_states[:, 0],
                                 config.hidden_size,
                                 kernel_init=get_kernel_init(config),
                                 name='pooler')
        pooled_output = jnp.tanh(pooled_output)

        return hidden_states, pooled_output
Code Example #23
  def apply(self,
            inputs,
            inputs_spatial_positions,
            inputs_scale_positions,
            inputs_masks,
            spatial_pos_grid_size,
            num_scales,
            num_layers,
            mlp_dim,
            use_sinusoid_pos_emb=False,
            use_scale_emb=True,
            dropout_rate=0.1,
            train=False,
            dtype=jnp.float32,
            stochastic_layer_drop_rate=0.0,
            **attention_kwargs):
    """Applies Transformer model on the inputs.

    Args:
      inputs: input data
      inputs_spatial_positions: input spatial positions for each embedding.
      inputs_scale_positions: input scale positions for each embedding.
      inputs_masks: bool, input mask.
      spatial_pos_grid_size: spatial positional encoding hash grid size.
      num_scales: number of scales input.
      num_layers: number of layers
      mlp_dim: dimension of the mlp on top of attention block.
      use_sinusoid_pos_emb: whether to use Sinusoidal Positional Embedding.
      use_scale_emb: use scale embedding.
      dropout_rate: dropout rate
      train: whether the model is in training mode.
      dtype: dtype of activations.
      stochastic_layer_drop_rate: probability of dropping a layer linearly grows
        from 0 to the provided value. Our implementation of stochastic depth
        follows timm library, which does per-example layer dropping and uses
        independent dropping patterns for each skip-connection.
      **attention_kwargs: kwargs passed to nn.SelfAttention

    Returns:
      output of a transformer encoder.
    """
    assert inputs.ndim == 3  # (batch, len, emb)
    dtype = jax.dtypes.canonicalize_dtype(dtype)

    if not use_sinusoid_pos_emb:
      x = AddHashSpatialPositionEmbs(
          inputs,
          spatial_pos_grid_size,
          inputs_positions=inputs_spatial_positions,
          posemb_init=nn.initializers.normal(stddev=0.02),  # from BERT.
          name="posembed_input")
    else:
      pos_emb_shape = (1, spatial_pos_grid_size * spatial_pos_grid_size,
                       inputs.shape[2])
      pe = get_sinusoid_encoding(pos_emb_shape[1], pos_emb_shape[2])
      pe = jnp.expand_dims(pe, axis=0)
      x = inputs + jnp.take(pe[0], inputs_spatial_positions, axis=0)

    if use_scale_emb:
      x = AddScaleEmbs(
          x,
          num_scales=num_scales,
          inputs_positions=inputs_scale_positions,
          scale_emb_init=nn.initializers.normal(stddev=0.02),
          name="scaleembed_input")

    n, _, c = x.shape
    cls = self.param("cls", (1, 1, c), nn.initializers.zeros)
    cls = jnp.tile(cls, [n, 1, 1])
    x = jnp.concatenate([cls, x], axis=1)

    cls_mask = jnp.ones((n, 1), dtype=inputs_masks.dtype)
    inputs_masks = jnp.concatenate([cls_mask, inputs_masks], axis=1)

    x = nn.dropout(x, rate=dropout_rate, deterministic=not train)

    # Input Encoder
    for lyr in range(num_layers):
      layer_drop_p = (lyr / max(num_layers - 1, 1)) * stochastic_layer_drop_rate
      x = Encoder1DBlock(
          x,
          mlp_dim=mlp_dim,
          inputs_masks=inputs_masks,
          dropout_rate=dropout_rate,
          deterministic=not train,
          name=f"encoderblock_{lyr}",
          dtype=dtype,
          layer_drop_p=layer_drop_p,
          **attention_kwargs)
    encoded = nn.LayerNorm(x, name="encoder_norm")

    return encoded
Code Example #24
    def apply(self,
              inputs,
              vocab_size,
              emb_dim=512,
              num_heads=8,
              num_layers=6,
              qkv_dim=512,
              mlp_dim=2048,
              max_len=2048,
              train=False,
              shift=True,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              cache=None):
        """Applies Transformer model on the inputs.

    Args:
      inputs: input data
      vocab_size: size of the vocabulary
      emb_dim: dimension of embedding
      num_heads: number of heads
      num_layers: number of layers
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      max_len: maximum length.
      train: bool: if model is training.
      shift: bool: if we right-shift input - this is only disabled for
        fast, looped single-token autoregressive decoding.
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      cache: flax autoregressive cache for fast decoding.

    Returns:
      output of a transformer decoder.
    """
        padding_mask = jnp.where(inputs > 0, 1, 0).astype(jnp.float32)[..., None]
        assert inputs.ndim == 2  # (batch, len)
        x = inputs
        if shift:
            x = shift_right(x)
        x = x.astype('int32')
        x = Embed(x, num_embeddings=vocab_size, features=emb_dim, name='embed')
        x = AddPositionEmbs(x,
                            max_len=max_len,
                            posemb_init=sinusoidal_init(max_len=max_len),
                            cache=cache)
        x = nn.dropout(x, rate=dropout_rate, deterministic=not train)
        for _ in range(num_layers):
            x = Transformer1DBlock(
                x,
                qkv_dim=qkv_dim,
                mlp_dim=mlp_dim,
                num_heads=num_heads,
                causal_mask=True,
                padding_mask=padding_mask,
                dropout_rate=dropout_rate,
                attention_dropout_rate=attention_dropout_rate,
                deterministic=not train,
                cache=cache,
            )
        x = nn.LayerNorm(x)
        logits = nn.Dense(x,
                          vocab_size,
                          kernel_init=nn.initializers.xavier_uniform(),
                          bias_init=nn.initializers.normal(stddev=1e-6))
        return logits
Code Example #25
  def apply(self,
            inputs,
            vocab_size,
            emb_dim=512,
            num_heads=8,
            num_layers=6,
            qkv_dim=512,
            mlp_dim=2048,
            max_len=2048,
            train=False,
            causal=True,
            shift=True,
            dropout_rate=0.1,
            attention_dropout_rate=0.1,
            normalizer='layer_norm',
            attention_fn=None,
            cache=None,
            pad_token=0):
    """Applies Transformer model on the inputs.

    Args:
      inputs: input data
      vocab_size: size of the vocabulary
      emb_dim: dimension of embedding
      num_heads: number of heads
      num_layers: number of layers
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      max_len: maximum length.
      train: bool: if model is training.
      causal: Whether to apply causal masking.
      shift: bool: if we right-shift input - this is only disabled for
        fast, looped single-token autoregressive decoding.
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      normalizer: One of 'batch_norm', 'layer_norm', 'none'
      attention_fn: Attention function to use. If None, defaults to
        nn.dot_product_attention.
      cache: flax autoregressive cache for fast decoding.
      pad_token: Indicates which input tokens are padded.

    Returns:
      output of a transformer decoder.
    """
    padding_mask = jnp.where(inputs != pad_token, 1, 0).astype(jnp.float32)
    assert inputs.ndim == 2  # (batch, len)
    x = inputs
    if shift:
      if not causal:
        raise ValueError('Cannot have shift=True and causal=False')
      x = shift_right(x)
    x = x.astype('int32')
    x = Embed(x, num_embeddings=vocab_size, features=emb_dim, name='embed')
    x = AddPositionEmbs(
        x,
        max_len=max_len,
        posemb_init=sinusoidal_init(max_len=max_len),
        cache=cache)
    x = nn.dropout(x, rate=dropout_rate, deterministic=not train)
    for _ in range(num_layers):
      x = Transformer1DBlock(
          x,
          qkv_dim=qkv_dim,
          mlp_dim=mlp_dim,
          num_heads=num_heads,
          causal_mask=causal,
          padding_mask=padding_mask,
          dropout_rate=dropout_rate,
          attention_dropout_rate=attention_dropout_rate,
          train=train,
          attention_fn=attention_fn,
          cache=cache,
          normalizer=normalizer,
      )
    if normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']:
      maybe_normalize = model_utils.get_normalizer(normalizer, train)
      x = maybe_normalize(x)
    logits = nn.Dense(
        x,
        vocab_size,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    return logits
Code Example #26
  def apply(self,
            inputs,
            qkv_dim,
            mlp_dim,
            num_heads,
            causal_mask=False,
            padding_mask=None,
            dropout_rate=0.1,
            attention_dropout_rate=0.1,
            train=True,
            normalizer='layer_norm',
            attention_fn=None,
            cache=None):
    """Applies Transformer1DBlock module.

    Args:
      inputs: input data
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      num_heads: number of heads
      causal_mask: bool, mask future or not
      padding_mask: bool, mask padding tokens
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      train: bool: if model is training.
      normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm',
        'pre_layer_norm', 'none'
      attention_fn: Attention function to use. If None, defaults to
        nn.dot_product_attention.
      cache: flax autoregressive cache for fast decoding.

    Returns:
      output after transformer block.

    """

    # Attention block.
    assert inputs.ndim == 3
    if normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm', 'none']:
      maybe_pre_normalize = model_utils.get_normalizer(normalizer, train)
      maybe_post_normalize = model_utils.get_normalizer('none', train)
    elif normalizer == 'post_layer_norm':
      maybe_pre_normalize = model_utils.get_normalizer('none', train)
      maybe_post_normalize = model_utils.get_normalizer(normalizer, train)
    else:
      raise ValueError('Unsupported normalizer: {}'.format(normalizer))

    x = maybe_pre_normalize(inputs)

    if attention_fn is None:
      attention_fn = nn.dot_product_attention
    x = nn.SelfAttention(
        x,
        num_heads=num_heads,
        qkv_features=qkv_dim,
        attention_axis=(1,),
        causal_mask=causal_mask,
        padding_mask=padding_mask,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6),
        bias=False,
        broadcast_dropout=False,
        attention_fn=attention_fn,
        dropout_rate=attention_dropout_rate,
        deterministic=not train,
        cache=cache)
    x = nn.dropout(x, rate=dropout_rate, deterministic=not train)
    x = x + inputs
    x = maybe_post_normalize(x)

    # MLP block.
    y = maybe_pre_normalize(x)
    y = MlpBlock(
        y, mlp_dim=mlp_dim, dropout_rate=dropout_rate, deterministic=not train)
    res = x + y

    return maybe_post_normalize(res)
Code Example #27
    def apply(self,
              x,
              *,
              self_attention_module,
              dim_intermediate,
              is_training,
              dropout_rate=0.1,
              use_pre_layernorm=False,
              layernorm_epsilon=1e-6,
              with_aux_outputs=True):
        """Compute self-attention with a feed-forward network on top.

    Args:
      x: Input representations.
      self_attention_module: Self-Attention layer.
      dim_intermediate: Size of the intermediate layer of the feed forward.
      is_training: Whether to enable dropout.
      dropout_rate: Dropout probability.
      use_pre_layernorm: Use pre layer norm from
        https://arxiv.org/abs/2002.04745.
      layernorm_epsilon: Epsilon parameter for all the layer norms.
      with_aux_outputs: Whether the self_attention_module has an aux output.

    Returns:
      New representations in a jnp.array of same shape as `x`.
    """
        dim_hidden = x.shape[-1]
        use_pre_ln = use_pre_layernorm
        use_post_ln = not use_pre_ln

        def apply_ln_if(pred, x, name):
            if pred:
                return nn.LayerNorm(x, epsilon=layernorm_epsilon, name=name)
            else:
                return x

        # attention
        x = apply_ln_if(use_pre_ln, x, "ln_pre_att")
        x_att = self_attention_module(x)
        if with_aux_outputs:
            x_att, output_aux = x_att

        # dropout norm and add
        x_att = nn.dropout(x_att, dropout_rate, deterministic=not is_training)
        x = x + x_att
        x = apply_ln_if(use_post_ln, x, "ln_post_att")

        # feed forward
        x_ffn = x
        x_ffn = apply_ln_if(use_pre_ln, x, "ln_pre_ffn")
        x_ffn = nn.Dense(x_ffn, dim_intermediate, name="ff_1")
        x_ffn = jax.nn.relu(x_ffn)
        x_ffn = nn.Dense(x_ffn, dim_hidden, name="ff_2")

        # dropout norm and add
        x_ffn = nn.dropout(x_ffn, dropout_rate, deterministic=not is_training)
        x = x + x_ffn
        x = apply_ln_if(use_post_ln, x, "ln_post_ffn")

        if with_aux_outputs:
            output = x, output_aux
        else:
            output = x
        return output
Code Example #28
    def apply(self,
              inputs,
              vocab_size,
              inputs_positions=None,
              inputs_segmentation=None,
              shared_embedding=None,
              use_bfloat16=False,
              emb_dim=512,
              num_heads=8,
              dtype=jnp.float32,
              num_layers=6,
              qkv_dim=512,
              mlp_dim=2048,
              max_len=512,
              train=True,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              learn_pos_emb=False,
              classifier=False,
              classifier_pool='CLS',
              num_classes=10,
              block_size=_DEFAULT_BLOCK_SIZE):
        """Applies BigBird transformer model on the inputs.

    Args:
      inputs: input data
      vocab_size: size of the vocabulary
      inputs_positions: input subsequence positions for packed examples.
      inputs_segmentation: input segmentation info for packed examples.
      shared_embedding: a shared embedding layer to use.
      use_bfloat16: bool: whether use bfloat16.
      emb_dim: dimension of embedding
      num_heads: number of heads
      dtype: the dtype of the computation (default: float32)
      num_layers: number of layers
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      max_len: maximum length.
      train: whether the model is in training mode.
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      learn_pos_emb: boolean, if learn the positional embedding or use the
        sinusoidal positional embedding.
      classifier: boolean, for classification mode (output N-class logits)
      classifier_pool: str, supports "CLS", "MEAN", "MAX" pooling.
      num_classes: int, number of classification classes.
      block_size: Size of attention blocks.

    Returns:
      output of a transformer encoder or logits if classifier_mode is true.
    """
        assert inputs.ndim == 2  # (batch, len)

        # Padding Masks
        src_padding_mask = (inputs > 0)[..., None]

        # Input Embedding
        if shared_embedding is None:
            input_embed = nn.Embed.partial(
                num_embeddings=vocab_size,
                features=emb_dim,
                embedding_init=nn.initializers.normal(stddev=1.0))
        else:
            input_embed = shared_embedding
        x = inputs.astype('int32')
        x = input_embed(x)

        if classifier and classifier_pool == 'CLS':
            cls = self.param('cls', (1, 1, emb_dim), nn.initializers.zeros)
            cls = jnp.tile(cls, [x.shape[0], 1, 1])
            x = jnp.concatenate([cls, x], axis=1)
            max_len += 1
            src_padding_mask = jnp.concatenate(
                [src_padding_mask[:, :1], src_padding_mask], axis=1)

        pe_init = nn.initializers.normal(
            stddev=0.02) if learn_pos_emb else None
        x = common_layers.AddPositionEmbs(x,
                                          inputs_positions=inputs_positions,
                                          posemb_init=pe_init,
                                          max_len=max_len,
                                          name='posembed_input')
        x = nn.dropout(x, rate=dropout_rate, deterministic=not train)

        if use_bfloat16:
            x = x.astype(jnp.bfloat16)
            dtype = jnp.bfloat16
        else:
            dtype = jnp.float32

        # Input Encoder
        for lyr in range(num_layers):
            x = BigBirdBlock(x,
                             qkv_dim=qkv_dim,
                             mlp_dim=mlp_dim,
                             num_heads=num_heads,
                             dtype=dtype,
                             padding_mask=src_padding_mask,
                             inputs_segmentation=inputs_segmentation,
                             dropout_rate=dropout_rate,
                             attention_dropout_rate=attention_dropout_rate,
                             deterministic=not train,
                             block_size=block_size,
                             connectivity_seed=lyr,
                             name=f'encoderblock_{lyr}')
        encoded = nn.LayerNorm(x, dtype=dtype, name='encoder_norm')

        if classifier:
            encoded = common_layers.classifier_head(
                encoded, num_classes, mlp_dim, pooling_mode=classifier_pool)
        return encoded
Code Example #29
    def apply(self,
              inputs,
              qkv_dim,
              mlp_dim,
              num_heads,
              dtype=jnp.float32,
              inputs_segmentation=None,
              causal_mask=False,
              padding_mask=None,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              deterministic=False,
              cache=None,
              block_size=_DEFAULT_BLOCK_SIZE,
              connectivity_seed=None):
        """Applies BigBirdBlock module.

    Args:
      inputs: input data
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      num_heads: number of heads
      dtype: the dtype of the computation (default: float32).
      inputs_segmentation: input segmentation info for packed examples.
      causal_mask: bool, mask future or not
      padding_mask: bool, mask padding tokens
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      deterministic: bool, deterministic or not (to apply dropout)
      cache: flax autoregressive cache for fast decoding.
      block_size: Size of attention blocks.
      connectivity_seed: Optional seed for random block sparse attention.

    Returns:
      output after transformer block.

    """

        # Attention block.
        assert inputs.ndim == 3
        x = nn.LayerNorm(inputs)
        x = bigbird_attention.BigBirdSelfAttention(
            x,
            num_heads=num_heads,
            dtype=dtype,
            qkv_features=qkv_dim,
            attention_axis=(1, ),
            causal_mask=causal_mask,
            segmentation=inputs_segmentation,
            padding_mask=padding_mask,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6),
            bias=False,
            broadcast_dropout=False,
            dropout_rate=attention_dropout_rate,
            deterministic=deterministic,
            cache=cache,
            block_size=block_size,
            connectivity_seed=connectivity_seed)
        x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
        x = x + inputs

        # MLP block.
        y = nn.LayerNorm(x)
        y = common_layers.MlpBlock(y,
                                   mlp_dim=mlp_dim,
                                   dtype=dtype,
                                   dropout_rate=dropout_rate,
                                   deterministic=deterministic)

        return x + y
Code Example #30
    def apply(self,
              encoded,
              src_padding_mask,
              targets,
              output_vocab_size,
              targets_positions=None,
              inputs_segmentation=None,
              targets_segmentation=None,
              tgt_padding_mask=None,
              shared_embedding=None,
              logits_via_embedding=False,
              shift=True,
              use_bfloat16=False,
              emb_dim=512,
              num_heads=8,
              num_layers=6,
              qkv_dim=512,
              mlp_dim=2048,
              max_len=2048,
              train=True,
              cache=None,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              num_partitions=1):
        """Applies Transformer model on the inputs.

    Args:
      encoded: encoded input data from encoder.
      src_padding_mask: padding mask for inputs.
      targets: target inputs.
      output_vocab_size: size of the vocabulary.
      targets_positions: input subsequence positions for packed examples.
      inputs_segmentation: input segmentation info for packed examples.
      targets_segmentation: target segmentation info for packed examples.
      tgt_padding_mask: target tokens padding mask.
      shared_embedding: a shared embedding matrix to use.
      logits_via_embedding: bool: whether final logit transform shares
        embedding weights.
      shift: whether to shift or not (for fast decoding).
      use_bfloat16: bool: whether use bfloat16.
      emb_dim: dimension of embedding
      num_heads: number of heads
      num_layers: number of layers
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      max_len: maximum length.
      train: whether the model is in training mode.
      cache: flax attention cache for fast decoding.
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      num_partitions: number of ways to partition (i.e. how many devices
        to run across).

    Returns:
      output of a transformer decoder.
    """
        assert encoded.ndim == 3  # (batch, len, depth)
        assert targets.ndim == 2  # (batch, len)

        if use_bfloat16:
            dtype = jnp.bfloat16
        else:
            dtype = jnp.float32

        # Padding Masks
        if tgt_padding_mask is None:
            tgt_padding_mask = (targets > 0)[..., None]

        # Target Embedding
        if shared_embedding is None:
            output_embed = Embed.shared(
                num_embeddings=output_vocab_size,
                features=emb_dim,
                embedding_init=nn.initializers.normal(stddev=emb_dim**-0.5),
                dtype=dtype,
                num_partitions=num_partitions)()
        else:
            output_embed = shared_embedding

        y = targets.astype('int32')
        if shift:
            y = shift_right(y)
        y = output_embed[y] * jnp.sqrt(emb_dim)
        y = y.astype(dtype)
        y = AddPositionEmbs(y,
                            inputs_positions=targets_positions,
                            cache=cache,
                            name='posembed_targets')
        y = nn.dropout(y, rate=dropout_rate, deterministic=not train)

        # Target-Input Decoder
        for lyr in range(num_layers):
            y = EncoderDecoder1DBlock(
                y,
                encoded,
                qkv_dim=qkv_dim,
                mlp_dim=mlp_dim,
                num_heads=num_heads,
                dtype=dtype,
                padding_mask=tgt_padding_mask,
                key_padding_mask=src_padding_mask,
                inputs_segmentation=inputs_segmentation,
                targets_segmentation=targets_segmentation,
                dropout_rate=dropout_rate,
                attention_dropout_rate=attention_dropout_rate,
                deterministic=not train,
                cache=cache,
                num_partitions=num_partitions,
                name=f'encoderdecoderblock_{lyr}')
        y = nn.LayerNorm(y, dtype=dtype, name='encoderdecoder_norm')
        y = y.reshape((-1, y.shape[-1]))

        # Decoded Logits
        if logits_via_embedding:
            # Use the transpose of embedding matrix for logit transform.
            logits = lax.dot_general(y, output_embed,
                                     (((y.ndim - 1, ), (1, )), ((), ())))
        else:
            logits = nn.Dense(y,
                              output_vocab_size,
                              dtype=dtype,
                              kernel_init=nn.initializers.xavier_uniform(),
                              bias_init=nn.initializers.normal(stddev=1e-6),
                              name='logitdense')
        return logits