Example #1
    def setup(self):
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)

        embed_dim = self.config.d_model
        self.padding_idx = self.config.pad_token_id
        self.max_target_positions = self.config.max_position_embeddings
        self.embed_scale = math.sqrt(
            self.config.d_model) if self.config.scale_embedding else 1.0

        self.embed_tokens = nn.Embed(
            self.config.vocab_size,
            embed_dim,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
        )

        # XGLM is set up so that if padding_idx is specified, the embedding ids are offset
        # by 2 and num_embeddings is adjusted accordingly. Other models don't have this hack.
        self.offset = 2
        self.embed_positions = create_sinusoidal_positions(
            self.config.max_position_embeddings + self.offset, embed_dim)
        self.layers = FlaxXGLMDecoderLayerCollection(self.config, self.dtype)
        self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
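
A minimal sketch of how the submodules defined above are typically wired together in the decoder's forward pass (a hypothetical `embed_inputs` helper; `input_ids` and `position_ids` are assumed int32 arrays of shape (batch, seq_len), `jnp` is assumed imported, and the decoder layers and final layer norm would be applied afterwards):

    def embed_inputs(self, input_ids, position_ids, deterministic=True):
        # Scale token embeddings by sqrt(d_model) when scale_embedding is set.
        inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
        # Look up sinusoidal positions, shifted by the XGLM-specific offset of 2.
        positions = jnp.take(self.embed_positions, position_ids + self.offset, axis=0)
        hidden_states = inputs_embeds + positions
        return self.dropout_layer(hidden_states, deterministic=deterministic)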
Example #2
    def __call__(self, inputs, inputs_positions=None, encoder_mask=None):
        """Applies Transformer model on the inputs.

    Args:
      inputs: input data
      inputs_positions: input subsequence positions for packed examples.
      encoder_mask: encoder self-attention mask.

    Returns:
      output of a transformer encoder.
    """
        cfg = self.config
        assert inputs.ndim == 2  # (batch, len)

        # Input Embedding
        if self.shared_embedding is None:
            input_embed = nn.Embed(
                num_embeddings=cfg.vocab_size,
                features=cfg.emb_dim,
                embedding_init=nn.initializers.normal(stddev=1.0))
        else:
            input_embed = self.shared_embedding
        x = inputs.astype('int32')
        x = input_embed(x)
        x = AddPositionEmbs(config=cfg, decode=False, name='posembed_input')(
            x, inputs_positions=inputs_positions)
        x = nn.Dropout(rate=cfg.dropout_rate)(x,
                                              deterministic=cfg.deterministic)

        x = x.astype(cfg.dtype)

        # Input Encoder
        for lyr in range(cfg.num_layers):
            x = Encoder1DBlock(config=cfg,
                               name=f'encoderblock_{lyr}')(x, encoder_mask)

        encoded = nn.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x)

        return encoded
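
A hypothetical driver for an encoder like this one, assuming the enclosing module is named `Encoder` with the `config`/`shared_embedding` attributes used above and that `cfg.deterministic` is True (so no dropout RNG is needed); the padding-based `encoder_mask` is built with `nn.make_attention_mask`, just as later examples do:

import jax
import jax.numpy as jnp
from flax import linen as nn

inputs = jnp.array([[5, 8, 2, 0, 0]], dtype=jnp.int32)           # 0 marks padding
encoder_mask = nn.make_attention_mask(inputs > 0, inputs > 0)    # block attention to pads
encoder = Encoder(config=cfg, shared_embedding=None)             # names assumed from context
params = encoder.init(jax.random.PRNGKey(0), inputs, encoder_mask=encoder_mask)
encoded = encoder.apply(params, inputs, encoder_mask=encoder_mask)  # (1, 5, cfg.emb_dim)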
Example #3
    def setup(self):
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)

        embed_dim = self.config.hidden_size
        self.padding_idx = self.config.pad_token_id
        self.max_target_positions = self.config.max_position_embeddings

        self.embed_tokens = nn.Embed(
            self.config.vocab_size,
            self.config.word_embed_proj_dim,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
        )

        self.embed_positions = FlaxOPTLearnedPositionalEmbedding(
            self.config.max_position_embeddings,
            embed_dim,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
        )

        if self.config.word_embed_proj_dim != self.config.hidden_size:
            self.project_in = nn.Dense(self.config.hidden_size, use_bias=False)
            self.project_out = nn.Dense(self.config.word_embed_proj_dim,
                                        use_bias=False)

        else:
            self.project_in = None
            self.project_out = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
        # with checkpoints that have been fine-tuned before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        if self.config.do_layer_norm_before and not self.config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(dtype=self.dtype,
                                                 epsilon=1e-05)
        else:
            self.final_layer_norm = None

        self.layers = FlaxOPTDecoderLayerCollection(self.config, self.dtype)
Example #4
    def __call__(self, x):
        out = {}

        embedding = nn.Embed(num_embeddings=self.vocab_size,
                             features=self.width)
        x = out['embedded'] = embedding(x)

        # Add posemb
        n, l, d = x.shape  # pylint: disable=unused-variable
        x = x + self.param('pos_embedding',
                           nn.initializers.normal(stddev=1 / jnp.sqrt(d)),
                           (1, l, d), x.dtype)

        x = models_vit.Encoder(num_layers=self.num_layers,
                               mlp_dim=self.mlp_dim,
                               num_heads=self.num_heads,
                               dropout_rate=self.dropout_rate,
                               attention_dropout_rate=0,
                               add_position_embedding=False)(x, train=False)

        x = out['pre_logits'] = x[:, -1, :]  # note that we take *last* token
        x = out['logits'] = nn.Dense(self.num_classes, name='head')(x)

        return x, out
Example #5
  def __call__(self,
               inputs,
               outputs):
    """Applies Transformer model to encode the IO specification.

    Args:
      inputs: input data [batch_size, num_io, length]
      outputs: output data [batch_size, num_io, length2]

    Returns:
      Encoded IO data `[batch_size, num_io, length2, dim]`
    """
    cfg = self.config

    # Inputs and outputs shared embeddings.
    embed = nn.Embed(
        num_embeddings=cfg.vocab_size,
        features=cfg.emb_dim,
        embedding_init=nn.initializers.normal(stddev=1.0),
        name='embed')

    if not cfg.use_relative_attention:
      pos_emb = AddPositionEmbs(config=cfg, cache=False, name='posembed_io')

    x = inputs.astype('int32')
    y = outputs.astype('int32')

    # Make attention masks.
    inputs_encoder_mask = nn.make_attention_mask(
        x > 0, x > 0, dtype=cfg.dtype)
    outputs_encoder_mask = nn.make_attention_mask(
        y > 0, y > 0, dtype=cfg.dtype)
    encoder_decoder_mask = nn.make_attention_mask(
        y > 0, x > 0, dtype=cfg.dtype)

    # Embed inputs.
    x = embed(x)
    if not cfg.use_relative_attention:
      x = pos_emb(x)
    x = nn.Dropout(rate=cfg.dropout_rate)(
        x, deterministic=cfg.deterministic)

    x = x.astype(cfg.dtype)
    for lyr in range(cfg.num_layers):
      x = EncoderBlock(   # Attend to inputs.
          config=cfg,
          bidirectional_attention=True,
          num_relative_position_buckets=(
              cfg.num_input_relative_position_buckets),
          max_distance=cfg.max_input_distance,
          name=f'encoderblock_{lyr}')(x, inputs_encoder_mask)
    x = nn.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x)

    # Embed outputs.
    y = embed(y)
    if not cfg.use_relative_attention:
      y = pos_emb(y)
    y = nn.Dropout(rate=cfg.dropout_rate)(
        y, deterministic=cfg.deterministic)

    encode_decoder_cfg = cfg.replace(decode=False)
    for lyr in range(cfg.num_layers):
      y = EncoderDecoderBlock(   # Double attend to inputs and outputs.
          config=encode_decoder_cfg,
          bidirectional_attention=True,
          num_relative_position_buckets=(
              cfg.num_output_relative_position_buckets),
          max_distance=cfg.max_output_distance,
          relative_cross_attention=cfg.use_relative_attention,
          bidirectional_cross_attention=True,
          num_relative_position_buckets_cross_attention=(
              cfg.num_input_cross_output_relative_position_buckets),
          max_distance_cross_attention=cfg.max_input_cross_output_distance,
          name=f'encoderdecoderblock_{lyr}')(
              y, x, outputs_encoder_mask, encoder_decoder_mask)
    y = nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_norm')(y)

    return y
Example #6
  def __call__(self,
               targets,
               encoded,
               decoder_mask=None,
               encoder_decoder_mask=None,
               decoder_relative_position=None,
               encoder_decoder_relative_position=None):
    """Applies Transformer to decode the targets.

    Args:
      targets: target outputs.
      encoded: encoded input data from encoder [batch, ..., length, mlp_dim].
      decoder_mask: decoder self-attention mask
      encoder_decoder_mask: encoder-decoder attention mask
      decoder_relative_position: decoder relative positions tensor
          `[batch_sizes..., length2, length2]`
      encoder_decoder_relative_position: encoder-decoder relative positions tensor
          `[batch_sizes..., length2, length]`

    Returns:
      output of a transformer decoder.
    """
    cfg = self.config

    assert encoded.ndim == targets.ndim + 1

    output_embed = nn.Embed(
        num_embeddings=cfg.output_vocab_size,
        features=cfg.emb_dim,
        embedding_init=nn.initializers.normal(stddev=1.0),
        name='embed_output')

    heads = dict()
    y = targets.astype('int32')
    if cfg.shift:
      y = shift_right(y, cfg.bos_token)

    y = output_embed(y)
    if not cfg.use_relative_attention:
      y = AddPositionEmbs(config=cfg, cache=cfg.decode,
                          name='posembed_output')(y)
    y = nn.Dropout(rate=cfg.dropout_rate)(
        y, deterministic=cfg.deterministic)

    y = y.astype(cfg.dtype)
    # Target-Input Decoder
    for lyr in range(cfg.num_layers):
      y = EncoderDecoderBlock(
          config=cfg,
          bidirectional_attention=False,
          num_relative_position_buckets=(
              cfg.num_program_relative_position_buckets),
          max_distance=cfg.max_program_distance,
          # relative_cross_attention=cfg.use_relative_attention,
          relative_cross_attention=False,
          bidirectional_cross_attention=True,
          num_relative_position_buckets_cross_attention=(
              cfg.num_program_cross_embed_relative_position_buckets),
          max_distance_cross_attention=cfg.max_program_cross_embed_distance,
          name=f'encoderdecoderblock_{lyr}')(
              y, encoded, decoder_mask, encoder_decoder_mask,
              decoder_relative_position, encoder_decoder_relative_position)
    y = nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_norm')(y)

    heads['output_emb'] = y * (
        jnp.where(targets > 0, 1, 0).astype(jnp.float32)[Ellipsis, None])

    logits = nn.Dense(
        cfg.output_vocab_size,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init,
        name='logitdense')(y)
    heads['logits'] = logits
    if cfg.output_head:
      return heads[cfg.output_head]
    else:
      return heads  # Return both output embeddings and logits.
Example #7
    def __call__(self, inputs, temb, train, context, permutations):
        """Applies Transformer model on the inputs.

    Args:
      inputs: Input data.
      temb: The time embedding.
      train: Is the model training?
      context: A context to condition on.
      permutations: A batch of permutations that specifies generation order.

    Returns:
      Output of a transformer decoder.
    """
        cfg = self.config
        assert inputs.ndim == 2  # (batch, len)
        deterministic = not train

        # Permutations give the permutation order, for XLNet-style training only. It
        # is important that permutations are applied _before shifting_. For this
        # reason, we also have to deal with the positional embeddings separately
        # at a later point.
        if permutations is not None:
            assert cfg.is_causal
            assert permutations.shape == inputs.shape

            # Use the permutations to act on the inputs.
            inputs = util_fns.batch_permute(inputs, permutations)

        # Target Embedding
        embedding_layer = nn.Embed(
            num_embeddings=cfg.output_vocab_size,
            features=cfg.emb_dim,
            embedding_init=nn.initializers.normal(stddev=1.0))

        # Concatenate context if available.
        if context is not None:
            assert cfg.context_length == context.shape[1], (
                f'{cfg.context_length} != {context.shape[1]} for {context.shape}')
            inputs = jnp.concatenate([context, inputs], axis=1)

        y = inputs.astype('int32')

        if cfg.is_causal:
            logging.info('Using causal Transformer')
            decoder_mask = nn.make_causal_mask(inputs, dtype=cfg.dtype)
        else:
            logging.info('Using fully connected (non-causal) Transformer')
            decoder_mask = None

        if cfg.is_causal:
            y = shift_inputs(y)
        y = embedding_layer(y)

        y = AddPositionEmbs(config=cfg, name='add_posemb')(y, permutations)

        y = nn.Dropout(rate=cfg.dropout_rate)(y, deterministic=deterministic)

        y = y.astype(cfg.dtype)

        # Target-Input Decoder
        for lyr in range(cfg.num_layers):
            y = EncoderDecoder1DBlock(config=cfg,
                                      name=f'encoderdecoderblock_{lyr}')(
                                          y,
                                          temb,
                                          deterministic,
                                          decoder_mask=decoder_mask)
        y = nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_norm')(y)

        logits = nn.Dense(cfg.output_vocab_size,
                          dtype=cfg.dtype,
                          kernel_init=cfg.kernel_init,
                          bias_init=cfg.bias_init,
                          name='logitdense')(y)

        if context is not None:
            # Take only predictions for inputs, not context.
            logits = logits[:, cfg.context_length:]

        if permutations is not None:
            assert cfg.is_causal
            # Apply the inverse permutation to the logits.
            inv_permutations = util_fns.compute_batch_inverse_permute(
                permutations)
            logits = util_fns.batch_permute(logits, inv_permutations)

        return logits
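
The permutation handling above relies on `util_fns.batch_permute` and `util_fns.compute_batch_inverse_permute`, which are not shown in this snippet; a hedged sketch of their assumed semantics in plain JAX:

import jax.numpy as jnp

def batch_permute(x, permutations):
    # Reorder each sequence along axis 1: out[b, i] = x[b, permutations[b, i]].
    # For 3-D inputs (e.g. logits), the index broadcasts over the feature axis.
    idx = permutations if x.ndim == 2 else permutations[..., None]
    return jnp.take_along_axis(x, idx, axis=1)

def compute_batch_inverse_permute(permutations):
    # The argsort of a permutation is its inverse.
    return jnp.argsort(permutations, axis=1)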
Example #8
    def __call__(self,
                 targets,
                 encoded,
                 decoder_mask=None,
                 encoder_decoder_mask=None):
        """Applies Transformer to decode the targets.

    Args:
      targets: target outputs.
      encoded: encoded input data from encoder [batch, ..., length, mlp_dim].
      decoder_mask: decoder self-attention mask
      encoder_decoder_mask: encoder-decoder attention mask

    Returns:
      output of a transformer decoder.
    """
        cfg = self.config

        assert encoded.ndim == targets.ndim + 1

        output_embed = nn.Embed(
            num_embeddings=cfg.output_vocab_size,
            features=cfg.emb_dim,
            embedding_init=nn.initializers.normal(stddev=1.0),
            name='embed_output')

        if cfg.use_relative_attention:
            attention_fn = functools.partial(
                relative_attention.RelativeMultiHeadDotProductAttention,
                num_relative_position_buckets=cfg.num_relative_position_buckets,
                causal=False)
            self_attention_fn = functools.partial(
                relative_attention.RelativeSelfAttention,
                num_relative_position_buckets=cfg.num_relative_position_buckets,
                causal=True)
        else:
            attention_fn = nn.MultiHeadDotProductAttention
            self_attention_fn = nn.SelfAttention

        heads = dict()
        y = targets.astype('int32')
        if cfg.shift:
            y = shift_right(y, cfg.bos_token)

        y = output_embed(y)
        if not cfg.use_relative_attention:
            y = AddPositionEmbs(config=cfg,
                                cache=cfg.decode,
                                name='posembed_output')(y)
        y = nn.Dropout(rate=cfg.dropout_rate)(y,
                                              deterministic=cfg.deterministic)

        y = y.astype(cfg.dtype)
        # Target-Input Decoder
        for lyr in range(cfg.num_layers):
            y = EncoderDecoderBlock(
                config=cfg,
                dot_product_attention_fn=attention_fn,
                self_attention_fn=self_attention_fn,
                name=f'encoderdecoderblock_{lyr}')(y, encoded, decoder_mask,
                                                   encoder_decoder_mask)
        y = nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_norm')(y)

        heads['output_emb'] = y * (jnp.where(targets > 0, 1, 0).astype(
            jnp.float32)[Ellipsis, None])

        logits = nn.Dense(cfg.output_vocab_size,
                          kernel_init=cfg.kernel_init,
                          bias_init=cfg.bias_init,
                          name='logitdense')(y)
        heads['logits'] = logits
        if cfg.output_head:
            return heads[cfg.output_head]
        else:
            return heads  # Return both output embeddings and logits.
Example #9
    def __call__(self,
                 encoded,
                 targets,
                 targets_positions=None,
                 decoder_mask=None,
                 encoder_decoder_mask=None):
        """Applies Transformer model on the inputs.

    Args:
      encoded: encoded input data from encoder.
      targets: target inputs.
      targets_positions: input subsequence positions for packed examples.
      decoder_mask: decoder self-attention mask.
      encoder_decoder_mask: encoder-decoder attention mask.

    Returns:
      output of a transformer decoder.
    """
        cfg = self.config

        assert encoded.ndim == 3  # (batch, len, depth)
        assert targets.ndim == 2  # (batch, len)

        # Target Embedding
        if self.shared_embedding is None:
            output_embed = nn.Embed(
                num_embeddings=cfg.output_vocab_size,
                features=cfg.emb_dim,
                embedding_init=nn.initializers.normal(stddev=1.0))
        else:
            output_embed = self.shared_embedding

        y = targets.astype('int32')
        if not cfg.decode:
            y = shift_right(y)
        y = output_embed(y)
        y = AddPositionEmbs(config=cfg,
                            decode=cfg.decode,
                            name='posembed_output')(
                                y, inputs_positions=targets_positions)
        y = nn.Dropout(rate=cfg.dropout_rate)(y,
                                              deterministic=cfg.deterministic)

        y = y.astype(cfg.dtype)

        # Target-Input Decoder
        for lyr in range(cfg.num_layers):
            y = EncoderDecoder1DBlock(
                config=cfg, name=f'encoderdecoderblock_{lyr}')(
                    y,
                    encoded,
                    decoder_mask=decoder_mask,
                    encoder_decoder_mask=encoder_decoder_mask)
        y = nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_norm')(y)

        # Decoded Logits
        if cfg.logits_via_embedding:
            # Use the transpose of embedding matrix for logit transform.
            logits = output_embed.attend(y.astype(jnp.float32))
            # Correctly normalize pre-softmax logits for this shared case.
            logits = logits / jnp.sqrt(y.shape[-1])
        else:
            logits = nn.Dense(cfg.output_vocab_size,
                              dtype=cfg.dtype,
                              kernel_init=cfg.kernel_init,
                              bias_init=cfg.bias_init,
                              name='logitdense')(y)
        return logits
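
When `cfg.logits_via_embedding` is set, `nn.Embed.attend` reuses the embedding table as the output projection (weight tying). A small standalone illustration of what that call computes, with made-up sizes:

import jax
import jax.numpy as jnp
from flax import linen as nn

embed = nn.Embed(num_embeddings=11, features=4)
variables = embed.init(jax.random.PRNGKey(0), jnp.zeros((1, 3), jnp.int32))
y = jnp.ones((1, 3, 4))                                    # decoder output states
logits = embed.apply(variables, y, method=embed.attend)    # y @ embedding.T -> (1, 3, 11)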
Example #10
 def test_embed_hash(self):
     self.assertEqual(hash(nn.Embed(2, 3)), hash(nn.Embed(2, 3)))
     self.assertNotEqual(hash(nn.Embed(3, 4)), hash(nn.Embed(2, 3)))
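
For reference, a minimal sketch of what an `nn.Embed(2, 3)` instance constructs and how it is initialized and applied (hypothetical driver code):

import jax
import jax.numpy as jnp
from flax import linen as nn

embed = nn.Embed(2, 3)                       # a table of 2 embeddings with 3 features each
variables = embed.init(jax.random.PRNGKey(0), jnp.array([[0, 1]]))
out = embed.apply(variables, jnp.array([[0, 1, 1, 0]]))    # shape (1, 4, 3)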
Example #11
    def __call__(self,
                 inputs,
                 train,
                 inputs_positions=None,
                 inputs_segmentation=None):
        """Applies Transformer model on the inputs.

    Args:
      inputs: input data
      train: bool: whether the model is training.
      inputs_positions: input subsequence positions for packed examples.
      inputs_segmentation: input segmentation info for packed examples.

    Returns:
      output of a transformer decoder.
    """
        assert inputs.ndim == 2  # (batch, len)
        dtype = utils.dtype_from_str(self.model_dtype)

        if self.decode:
            # for fast autoregressive decoding we use no decoder mask
            decoder_mask = None
        else:
            decoder_mask = nn.combine_masks(
                nn.make_attention_mask(inputs > 0, inputs > 0, dtype=dtype),
                nn.make_causal_mask(inputs, dtype=dtype))

        if inputs_segmentation is not None:
            decoder_mask = nn.combine_masks(
                decoder_mask,
                nn.make_attention_mask(inputs_segmentation,
                                       inputs_segmentation,
                                       jnp.equal,
                                       dtype=dtype))

        y = inputs.astype('int32')
        if not self.decode:
            y = shift_inputs(y, segment_ids=inputs_segmentation)

        # TODO(gdahl,znado): this code appears to be accessing out-of-bounds
        # indices for dataset_lib:proteins_test. This will break when jnp.take() is
        # updated to return NaNs for out-of-bounds indices.
        # Debug why this is the case.
        y = jnp.clip(y, 0, self.vocab_size - 1)

        if self.shared_embedding is None:
            output_embed = nn.Embed(
                num_embeddings=self.vocab_size,
                features=self.emb_dim,
                embedding_init=nn.initializers.normal(stddev=1.0))
        else:
            output_embed = self.shared_embedding

        y = output_embed(y)

        y = AddPositionEmbs(max_len=self.max_len,
                            posemb_init=sinusoidal_init(max_len=self.max_len),
                            decode=self.decode,
                            name='posembed_output')(
                                y, inputs_positions=inputs_positions)
        y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=not train)

        y = y.astype(dtype)

        for _ in range(self.num_layers):
            y = Transformer1DBlock(
                qkv_dim=self.qkv_dim,
                mlp_dim=self.mlp_dim,
                num_heads=self.num_heads,
                dropout_rate=self.dropout_rate,
                attention_dropout_rate=self.attention_dropout_rate,
                attention_fn=self.attention_fn,
                normalizer=self.normalizer,
                dtype=dtype)(
                    inputs=y,
                    train=train,
                    decoder_mask=decoder_mask,
                    encoder_decoder_mask=None,
                    inputs_positions=None,
                    inputs_segmentation=None,
                )
        if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']:
            maybe_normalize = model_utils.get_normalizer(self.normalizer,
                                                         train,
                                                         dtype=dtype)
            y = maybe_normalize()(y)

        if self.logits_via_embedding:
            # Use the transpose of embedding matrix for logit transform.
            logits = output_embed.attend(y.astype(jnp.float32))
            # Correctly normalize pre-softmax logits for this shared case.
            logits = logits / jnp.sqrt(y.shape[-1])
        else:
            logits = nn.Dense(self.vocab_size,
                              kernel_init=nn.initializers.xavier_uniform(),
                              bias_init=nn.initializers.normal(stddev=1e-6),
                              dtype=dtype,
                              name='logits_dense')(y)

        return logits.astype(dtype)
Example #12
 def setup(self):
     self.embed = nn.Embed(
         num_embeddings=self.num_embeddings,
         features=self.features,
         embedding_init=self.embedding_init,
     )
Example #13
    def __call__(self,
                 encoded,
                 targets,
                 targets_positions=None,
                 decoder_mask=None,
                 encoder_decoder_mask=None,
                 train=True):
        """Applies Transformer model on the inputs.

    Args:
      encoded: encoded input data from encoder.
      targets: target inputs.
      targets_positions: input subsequence positions for packed examples.
      decoder_mask: decoder self-attention mask.
      encoder_decoder_mask: encoder-decoder attention mask.
      train: whether it is training.

    Returns:
      output of a transformer decoder.
    """
        assert encoded.ndim == 3  # (batch, len, depth)
        assert targets.ndim == 2  # (batch, len)
        dtype = _get_dtype(self.use_bfloat16)

        # Target Embedding
        if self.shared_embedding is None:
            output_embed = nn.Embed(
                num_embeddings=self.output_vocab_size,
                features=self.emb_dim,
                embedding_init=nn.initializers.normal(stddev=1.0),
                name='output_vocab_embeddings')
        else:
            output_embed = self.shared_embedding

        y = targets.astype('int32')
        if not self.decode:
            y = shift_right(y)
        y = output_embed(y)
        y = AddPositionEmbs(max_len=self.max_len,
                            decode=self.decode,
                            name='posembed_output')(
                                y,
                                inputs_positions=targets_positions,
                                train=train)
        y = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(y)

        if self.use_bfloat16:
            y = y.astype(jnp.bfloat16)

        # Target-Input Decoder
        for lyr in range(self.dec_num_layers):
            y = EncoderDecoder1DBlock(
                qkv_dim=self.qkv_dim,
                mlp_dim=self.mlp_dim,
                num_heads=self.num_heads,
                dtype=dtype,
                dropout_rate=self.dropout_rate,
                attention_dropout_rate=self.attention_dropout_rate,
                normalizer=self.normalizer,
                dec_self_attn_kernel_init_fn=self.dec_self_attn_kernel_init_fn,
                dec_cross_attn_kernel_init_fn=self.dec_cross_attn_kernel_init_fn,
                decode=self.decode,
                name=f'encoderdecoderblock_{lyr}')(
                    y,
                    encoded,
                    decoder_mask=decoder_mask,
                    encoder_decoder_mask=encoder_decoder_mask,
                    train=train)
        if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']:
            maybe_normalize = model_utils.get_normalizer(
                self.normalizer, train)
            y = maybe_normalize()(y)

        # Decoded Logits
        if self.logits_via_embedding:
            # Use the transpose of embedding matrix for logit transform.
            logits = output_embed.attend(y.astype(jnp.float32))
            # Correctly normalize pre-softmax logits for this shared case.
            logits = logits / jnp.sqrt(y.shape[-1])

        else:
            logits = nn.Dense(self.output_vocab_size,
                              dtype=dtype,
                              kernel_init=nn.initializers.xavier_uniform(),
                              bias_init=nn.initializers.normal(stddev=1e-6),
                              name='logitdense')(y)
        return logits
Example #14
    def __call__(self,
                 inputs,
                 inputs_positions=None,
                 encoder_mask=None,
                 train=True):
        """Applies Transformer model on the inputs.

    Args:
      inputs: input data
      inputs_positions: input subsequence positions for packed examples.
      encoder_mask: encoder self-attention mask.
      train: whether the model is training.

    Returns:
      output of a transformer encoder.
    """
        assert inputs.ndim == 2  # (batch, len)

        # Input embedding.
        if self.shared_embedding is None:
            input_embed = nn.Embed(
                num_embeddings=self.vocab_size,
                features=self.emb_dim,
                embedding_init=nn.initializers.normal(stddev=1.0),
                name='input_vocab_embeddings')
        else:
            input_embed = self.shared_embedding
        x = inputs.astype('int32')
        x = input_embed(x)
        x = AddPositionEmbs(max_len=self.max_len,
                            decode=False,
                            name='posembed_input')(
                                x,
                                inputs_positions=inputs_positions,
                                train=train)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)

        if self.use_bfloat16:
            x = x.astype(jnp.bfloat16)
            dtype = jnp.bfloat16
        else:
            dtype = jnp.float32

        # Input encoder.
        for lyr in range(self.enc_num_layers):
            x = Encoder1DBlock(
                qkv_dim=self.qkv_dim,
                mlp_dim=self.mlp_dim,
                num_heads=self.num_heads,
                dtype=dtype,
                dropout_rate=self.dropout_rate,
                attention_dropout_rate=self.attention_dropout_rate,
                normalizer=self.normalizer,
                enc_self_attn_kernel_init_fn=self.enc_self_attn_kernel_init_fn,
                name=f'encoderblock_{lyr}')(x,
                                            encoder_mask=encoder_mask,
                                            train=train)
        if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']:
            maybe_normalize = model_utils.get_normalizer(
                self.normalizer, train)
            x = maybe_normalize()(x)
        return x