Example #1
class DeepViTConfig:
    num_classes: int = 1000
    depth: int = 32
    mlp_dim: int = 1224
    token_dim: int = 64
    emb_dim: int = 408
    num_heads: int = 12
    dim_head: int = 32
    shared_theta: bool = True
    activation_fn: ModuleDef = nn.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Any = jax.lax.Precision.DEFAULT
    kernel_init: Callable = initializers.xavier_uniform()
    bias_init: Callable = initializers.normal(stddev=1e-6)
    posemb_init: Callable = initializers.normal(stddev=0.02)
Example #2
class TNTConfig:
    """Global hyperparameters used to minimize obnoxious kwarg plumbing."""
    num_classes: int = 1000
    depth: int = 12
    image_size: int = 224
    patch_size: int = 16
    transformed_patch_size: int = 4
    inner_dim: int = 40
    inner_heads: int = 4
    inner_dim_head: int = 64
    inner_r: int = 4
    outer_dim: int = 640
    outer_heads: int = 10
    outer_dim_head: int = 64
    outer_r: int = 4
    dtype: Any = jnp.float32
    kernel_init: Callable = initializers.xavier_uniform()
    bias_init: Callable = initializers.normal(stddev=1e-6)
    posemb_init: Callable = initializers.normal(stddev=0.02)
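# A minimal, hypothetical sketch (not from the source) of how config classes like
# DeepViTConfig/TNTConfig above are typically consumed once decorated as
# dataclasses (the decorator is not shown in the examples): build one object,
# override a few fields, and pass the single object around instead of many kwargs.
import dataclasses

@dataclasses.dataclass
class DemoConfig:          # illustrative stand-in, not the real TNTConfig
    num_classes: int = 1000
    depth: int = 12
    patch_size: int = 16

base = DemoConfig()
small = dataclasses.replace(base, depth=4, num_classes=10)
print(small)               # DemoConfig(num_classes=10, depth=4, patch_size=16)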
Example #3
class Transformer(nn.Module):
    """Transformer Model for sequence to sequence translation.

    vocab_size: size of the input vocabulary.
    output_vocab_size: size of the output vocabulary. If None, the output
      vocabulary size is assumed to be the same as vocab_size.
    share_embeddings: bool: share embedding layer for inputs and targets.
    logits_via_embedding: bool: whether final logit transform shares embedding
      weights.
    use_bfloat16: bool: whether use bfloat16.
    emb_dim: dimension of embedding.
    num_heads: number of heads.
    enc_num_layers: number of encoder layers.
    dec_num_layers: number of decoder layers.
    qkv_dim: dimension of the query/key/value.
    mlp_dim: dimension of the mlp on top of attention block.
    max_len: maximum length.
    dropout_rate: dropout rate.
    attention_dropout_rate: dropout rate for attention weights.
    normalizer: One of 'batch_norm', 'layer_norm', 'none'
    enc_self_attn_kernel_init_fn: initializer for encoder's
      self attention matrices.
    dec_self_attn_kernel_init_fn: initializer for decoder's
      self attention matrices.
    dec_cross_attn_kernel_init_fn: initializer for decoder's
      cross attention matrices.
    should_decode: whether to use an autoregressive cache.
  """
    vocab_size: Optional[int] = None
    output_vocab_size: Optional[int] = None
    share_embeddings: bool = False
    logits_via_embedding: bool = False
    use_bfloat16: bool = False
    emb_dim: int = 512
    num_heads: int = 8
    enc_num_layers: int = 6
    dec_num_layers: int = 6
    qkv_dim: int = 512
    mlp_dim: int = 2048
    max_len: int = 2048
    dropout_rate: float = 0.3
    attention_dropout_rate: float = 0.3
    normalizer: str = 'layer_norm'
    enc_self_attn_kernel_init_fn: model_utils.Initializer = initializers.xavier_uniform()  # pylint: disable=line-too-long
    dec_self_attn_kernel_init_fn: model_utils.Initializer = initializers.xavier_uniform()  # pylint: disable=line-too-long
    dec_cross_attn_kernel_init_fn: model_utils.Initializer = initializers.xavier_uniform()  # pylint: disable=line-too-long
    should_decode: bool = False

    def setup(self):
        if self.share_embeddings:
            if self.output_vocab_size is not None:
                assert self.output_vocab_size == self.vocab_size, (
                    "can't share embedding with different vocab sizes.")
            self.shared_embedding = nn.Embed(
                num_embeddings=self.vocab_size,
                features=self.emb_dim,
                embedding_init=nn.initializers.normal(stddev=1.0),
                name='VocabEmbeddings')
        else:
            self.shared_embedding = None

        self.encoder = Encoder(
            vocab_size=self.vocab_size,
            shared_embedding=self.shared_embedding,
            use_bfloat16=self.use_bfloat16,
            emb_dim=self.emb_dim,
            num_heads=self.num_heads,
            enc_num_layers=self.enc_num_layers,
            qkv_dim=self.qkv_dim,
            mlp_dim=self.mlp_dim,
            max_len=self.max_len,
            dropout_rate=self.dropout_rate,
            attention_dropout_rate=self.attention_dropout_rate,
            normalizer=self.normalizer,
            enc_self_attn_kernel_init_fn=self.enc_self_attn_kernel_init_fn,
            name='encoder')
        self.decoder = Decoder(
            output_vocab_size=self.output_vocab_size,
            shared_embedding=self.shared_embedding,
            logits_via_embedding=self.logits_via_embedding,
            use_bfloat16=self.use_bfloat16,
            emb_dim=self.emb_dim,
            num_heads=self.num_heads,
            dec_num_layers=self.dec_num_layers,
            qkv_dim=self.qkv_dim,
            mlp_dim=self.mlp_dim,
            max_len=self.max_len,
            dropout_rate=self.dropout_rate,
            attention_dropout_rate=self.attention_dropout_rate,
            normalizer=self.normalizer,
            dec_self_attn_kernel_init_fn=self.dec_self_attn_kernel_init_fn,
            dec_cross_attn_kernel_init_fn=self.dec_cross_attn_kernel_init_fn,
            decode=self.should_decode,
            name='decoder')

    @nn.compact
    def __call__(self,
                 inputs,
                 targets,
                 inputs_positions=None,
                 targets_positions=None,
                 inputs_segmentation=None,
                 targets_segmentation=None,
                 train=False):
        """Applies Transformer model on the inputs.

    Args:
      inputs: input data.
      targets: target data.
      inputs_positions: input subsequence positions for packed examples.
      targets_positions: target subsequence positions for packed examples.
      inputs_segmentation: input segmentation info for packed examples.
      targets_segmentation: target segmentation info for packed examples.
      train: whether it is training.

    Returns:
      Output: <float>[batch_size, target_sequence_length, output_vocab_size]
    """
        encoded = self.encode(inputs,
                              inputs_positions=inputs_positions,
                              inputs_segmentation=inputs_segmentation,
                              train=train)

        logits = self.decode(
            encoded,
            inputs,  # only used for masks
            targets,
            targets_positions=targets_positions,
            inputs_segmentation=inputs_segmentation,
            targets_segmentation=targets_segmentation,
            train=train)
        return logits.astype(jnp.float32) if self.use_bfloat16 else logits

    # The following two methods allow us to run the trained Transformer in
    # two parts during fast decoding.  First, we call the encoder branch to
    # encode the inputs, then we call the decoder branch while using a
    # cache object for iteratively storing keys and values during the decoding
    # process. A usage sketch of this two-phase flow follows after the class.

    def encode(self,
               inputs,
               inputs_positions=None,
               inputs_segmentation=None,
               train=False):
        # Make padding attention mask.
        encoder_mask = nn.make_attention_mask(inputs > 0,
                                              inputs > 0,
                                              dtype=inputs.dtype)
        # Add segmentation block-diagonal attention mask if using segmented data.
        if inputs_segmentation is not None:
            encoder_mask = nn.combine_masks(
                encoder_mask,
                nn.make_attention_mask(inputs_segmentation,
                                       inputs_segmentation,
                                       jnp.equal,
                                       dtype=inputs_segmentation.dtype))
        encoded = self.encoder(inputs,
                               inputs_positions=inputs_positions,
                               encoder_mask=encoder_mask,
                               train=train)

        return encoded

    def decode(self,
               encoded,
               inputs,
               targets,
               targets_positions=None,
               inputs_segmentation=None,
               targets_segmentation=None,
               train=False):
        # Make padding attention masks.
        dtype = jnp.bfloat16 if self.use_bfloat16 else jnp.float32
        if self.should_decode:
            # For fast autoregressive decoding, only a special encoder-decoder mask is
            # used.
            decoder_mask = None
            encoder_decoder_mask = nn.make_attention_mask(
                jnp.ones_like(targets) > 0, inputs > 0, dtype=dtype)
        else:
            decoder_mask = nn.combine_masks(
                nn.make_attention_mask(targets > 0, targets > 0, dtype=dtype),
                nn.make_causal_mask(targets, dtype=dtype))
            encoder_decoder_mask = nn.make_attention_mask(targets > 0,
                                                          inputs > 0,
                                                          dtype=dtype)

        # Add segmentation block-diagonal attention masks if using segmented data.
        if inputs_segmentation is not None:
            decoder_mask = nn.combine_masks(
                decoder_mask,
                nn.make_attention_mask(targets_segmentation,
                                       targets_segmentation,
                                       jnp.equal,
                                       dtype=dtype))
            encoder_decoder_mask = nn.combine_masks(
                encoder_decoder_mask,
                nn.make_attention_mask(targets_segmentation,
                                       inputs_segmentation,
                                       jnp.equal,
                                       dtype=dtype))

        logits = self.decoder(encoded,
                              targets,
                              targets_positions=targets_positions,
                              decoder_mask=decoder_mask,
                              encoder_decoder_mask=encoder_decoder_mask,
                              train=train)
        return logits
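# A minimal sketch (not from the source) of the two-phase fast-decoding flow
# described in the comment before `encode` above, assuming the Transformer class
# and its Encoder/Decoder/AddPositionEmbs dependencies are in scope. Vocabulary
# size, sequence lengths, and greedy argmax selection are illustrative choices.
import jax
import jax.numpy as jnp

model = Transformer(vocab_size=32000, output_vocab_size=32000, should_decode=True)
src = jnp.ones((1, 16), jnp.int32)            # dummy source token ids
tgt_buffer = jnp.zeros((1, 32), jnp.int32)    # full-length target, used to size the cache

variables = model.init(jax.random.PRNGKey(0), src, tgt_buffer, train=False)
params, cache = variables['params'], variables['cache']

# Phase 1: run the encoder branch once over the source sequence.
encoded = model.apply({'params': params}, src, method=Transformer.encode)

# Phase 2: run the decoder branch one position at a time, threading the
# autoregressive cache through `mutable=['cache']`.
cur_token = jnp.zeros((1, 1), jnp.int32)      # BOS assumed to be token id 0
for _ in range(tgt_buffer.shape[1]):
    logits, mutated = model.apply(
        {'params': params, 'cache': cache},
        encoded, src, cur_token,
        method=Transformer.decode, mutable=['cache'])
    cache = mutated['cache']
    cur_token = jnp.argmax(logits, axis=-1)   # greedy pick for the next step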
Example #4
class Decoder(nn.Module):
    """Transformer Model Decoder for sequence to sequence translation.

    output_vocab_size: size of the vocabulary.
    shared_embedding: a shared embedding layer to use.
    logits_via_embedding: bool: whether final logit transform shares embedding
      weights.
    use_bfloat16: bool: whether use bfloat16.
    emb_dim: dimension of embedding.
    num_heads: number of heads.
    dec_num_layers: number of layers.
    qkv_dim: dimension of the query/key/value.
    mlp_dim: dimension of the mlp on top of attention block.
    max_len: maximum length.
    decode: whether to use an autoregressive cache.
    dropout_rate: dropout rate.
    normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm',
      'pre_layer_norm', 'none'
    attention_dropout_rate: dropout rate for attention weights.
    dec_self_attn_kernel_init_fn: initializer for decoder's
      self attention matrices.
    dec_cross_attn_kernel_init_fn: initializer for decoder's
      cross attention matrices.
  """
    output_vocab_size: int
    shared_embedding: Any = None
    logits_via_embedding: bool = False
    use_bfloat16: bool = False
    emb_dim: int = 512
    num_heads: int = 8
    dec_num_layers: int = 6
    qkv_dim: int = 512
    mlp_dim: int = 2048
    max_len: int = 512
    decode: bool = False
    dropout_rate: float = 0.1
    normalizer: str = 'layer_norm'
    attention_dropout_rate: float = 0.1
    dec_self_attn_kernel_init_fn: model_utils.Initializer = initializers.xavier_uniform()  # pylint: disable=line-too-long
    dec_cross_attn_kernel_init_fn: model_utils.Initializer = initializers.xavier_uniform()  # pylint: disable=line-too-long

    @nn.compact
    def __call__(self,
                 encoded,
                 targets,
                 targets_positions=None,
                 decoder_mask=None,
                 encoder_decoder_mask=None,
                 train=True):
        """Applies Transformer model on the inputs.

    Args:
      encoded: encoded input data from encoder.
      targets: target inputs.
      targets_positions: input subsequence positions for packed examples.
      decoder_mask: decoder self-attention mask.
      encoder_decoder_mask: encoder-decoder attention mask.
      train: whether it is training.

    Returns:
      output of a transformer decoder.
    """
        assert encoded.ndim == 3  # (batch, len, depth)
        assert targets.ndim == 2  # (batch, len)
        dtype = _get_dtype(self.use_bfloat16)

        # Target Embedding
        if self.shared_embedding is None:
            output_embed = nn.Embed(
                num_embeddings=self.output_vocab_size,
                features=self.emb_dim,
                embedding_init=nn.initializers.normal(stddev=1.0),
                name='output_vocab_embeddings')
        else:
            output_embed = self.shared_embedding

        y = targets.astype('int32')
        if not self.decode:
            y = shift_right(y)
        y = output_embed(y)
        y = AddPositionEmbs(max_len=self.max_len,
                            decode=self.decode,
                            name='posembed_output')(
                                y,
                                inputs_positions=targets_positions,
                                train=train)
        y = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(y)

        if self.use_bfloat16:
            y = y.astype(jnp.bfloat16)

        # Target-Input Decoder
        for lyr in range(self.dec_num_layers):
            y = EncoderDecoder1DBlock(
                qkv_dim=self.qkv_dim,
                mlp_dim=self.mlp_dim,
                num_heads=self.num_heads,
                dtype=dtype,
                dropout_rate=self.dropout_rate,
                attention_dropout_rate=self.attention_dropout_rate,
                normalizer=self.normalizer,
                dec_self_attn_kernel_init_fn=self.dec_self_attn_kernel_init_fn,
                dec_cross_attn_kernel_init_fn=self.dec_cross_attn_kernel_init_fn,
                decode=self.decode,
                name=f'encoderdecoderblock_{lyr}')(
                    y,
                    encoded,
                    decoder_mask=decoder_mask,
                    encoder_decoder_mask=encoder_decoder_mask,
                    train=train)
        if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']:
            maybe_normalize = model_utils.get_normalizer(
                self.normalizer, train)
            y = maybe_normalize()(y)

        # Decoded Logits
        if self.logits_via_embedding:
            # Use the transpose of embedding matrix for logit transform.
            logits = output_embed.attend(y.astype(jnp.float32))
            # Correctly normalize pre-softmax logits for this shared case.
            logits = logits / jnp.sqrt(y.shape[-1])

        else:
            logits = nn.Dense(self.output_vocab_size,
                              dtype=dtype,
                              kernel_init=nn.initializers.xavier_uniform(),
                              bias_init=nn.initializers.normal(stddev=1e-6),
                              name='logitdense')(y)
        return logits
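# A small standalone sketch (not from the source) of the weight-tied logit
# transform used above when `logits_via_embedding=True`: one nn.Embed table maps
# ids to vectors, `attend` reuses its transpose to produce vocabulary logits, and
# the result is rescaled by 1/sqrt(emb_dim) as in the Decoder. Class name and
# sizes here are illustrative.
import jax
import jax.numpy as jnp
from flax import linen as nn

class TiedLogitHead(nn.Module):          # hypothetical helper, not in the source
    vocab_size: int = 100
    emb_dim: int = 16

    @nn.compact
    def __call__(self, token_ids):
        embed = nn.Embed(num_embeddings=self.vocab_size, features=self.emb_dim)
        y = embed(token_ids)                     # [batch, len, emb_dim]
        logits = embed.attend(y)                 # [batch, len, vocab_size]
        return logits / jnp.sqrt(y.shape[-1])    # normalize pre-softmax logits

ids = jnp.array([[1, 2, 3]])
params = TiedLogitHead().init(jax.random.PRNGKey(0), ids)
print(TiedLogitHead().apply(params, ids).shape)  # (1, 3, 100)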
Example #5
class Encoder(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation.

    vocab_size: size of the vocabulary
    shared_embedding: a shared embedding layer to use.
    use_bfloat16: bool: whether use bfloat16.
    emb_dim: dimension of embedding
    num_heads: number of heads
    enc_num_layers: number of layers
    qkv_dim: dimension of the query/key/value
    mlp_dim: dimension of the mlp on top of attention block
    max_len: maximum length.
    dropout_rate: dropout rate
    normalizer: One of 'batch_norm', 'layer_norm', 'none'
    attention_dropout_rate: dropout rate for attention weights
    enc_self_attn_kernel_init_fn: initializer for encoder's
      self attention matrices.
  """
    vocab_size: int
    shared_embedding: Any = None
    use_bfloat16: bool = False
    emb_dim: int = 512
    num_heads: int = 8
    enc_num_layers: int = 6
    qkv_dim: int = 512
    mlp_dim: int = 2048
    max_len: int = 512
    dropout_rate: float = 0.1
    normalizer: str = 'layer_norm'
    attention_dropout_rate: float = 0.1
    enc_self_attn_kernel_init_fn: model_utils.Initializer = initializers.xavier_uniform()  # pylint: disable=line-too-long

    @nn.compact
    def __call__(self,
                 inputs,
                 inputs_positions=None,
                 encoder_mask=None,
                 train=True):
        """Applies Transformer model on the inputs.

    Args:
      inputs: input data
      inputs_positions: input subsequence positions for packed examples.
      encoder_mask: encoder self-attention mask.
      train: if it is training.

    Returns:
      output of a transformer encoder.
    """
        assert inputs.ndim == 2  # (batch, len)

        # Input embedding.
        if self.shared_embedding is None:
            input_embed = nn.Embed(
                num_embeddings=self.vocab_size,
                features=self.emb_dim,
                embedding_init=nn.initializers.normal(stddev=1.0),
                name='input_vocab_embeddings')
        else:
            input_embed = self.shared_embedding
        x = inputs.astype('int32')
        x = input_embed(x)
        x = AddPositionEmbs(max_len=self.max_len,
                            decode=False,
                            name='posembed_input')(
                                x,
                                inputs_positions=inputs_positions,
                                train=train)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)

        if self.use_bfloat16:
            x = x.astype(jnp.bfloat16)
            dtype = jnp.bfloat16
        else:
            dtype = jnp.float32

        # Input encoder.
        for lyr in range(self.enc_num_layers):
            x = Encoder1DBlock(
                qkv_dim=self.qkv_dim,
                mlp_dim=self.mlp_dim,
                num_heads=self.num_heads,
                dtype=dtype,
                dropout_rate=self.dropout_rate,
                attention_dropout_rate=self.attention_dropout_rate,
                normalizer=self.normalizer,
                enc_self_attn_kernel_init_fn=self.enc_self_attn_kernel_init_fn,
                name=f'encoderblock_{lyr}')(x,
                                            encoder_mask=encoder_mask,
                                            train=train)
        if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']:
            maybe_normalize = model_utils.get_normalizer(
                self.normalizer, train)
            x = maybe_normalize()(x)
        return x
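# A small standalone illustration (not from the source) of how the padding and
# segmentation masks consumed by this Encoder can be built with Flax utilities,
# mirroring Transformer.encode in Example #3 above. Token id 0 is treated as
# padding; the token and segment values below are made up.
import jax.numpy as jnp
from flax import linen as nn

tokens = jnp.array([[5, 8, 3, 0, 0]])        # one padded sequence
segments = jnp.array([[1, 1, 2, 0, 0]])      # two packed sub-sequences

# Padding mask: queries may only attend to non-padding keys.
encoder_mask = nn.make_attention_mask(tokens > 0, tokens > 0, dtype=jnp.float32)

# Block-diagonal segmentation mask: additionally restrict attention to tokens
# that belong to the same packed sub-sequence.
encoder_mask = nn.combine_masks(
    encoder_mask,
    nn.make_attention_mask(segments, segments, jnp.equal, dtype=jnp.float32))

print(encoder_mask.shape)   # (1, 1, 5, 5): [batch, head dim, query len, key len]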
Example #6
class EncoderDecoder1DBlock(nn.Module):
    """Transformer encoder-decoder layer.

  Attributes:
    qkv_dim: Dimension of the query/key/value.
    mlp_dim: Dimension of the mlp on top of attention block.
    num_heads: Number of heads.
    dtype: Dtype of the computation (default: float32).
    dropout_rate: <float> Dropout rate.
    attention_dropout_rate: <float> Dropout rate for attention weights.
    normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm',
      'pre_layer_norm', 'none'
    dec_self_attn_kernel_init_fn: initializer for decoder's
      self attention matrices.
    dec_cross_attn_kernel_init_fn: initializer for decoder's
      cross attention matrices.
    decode: whether to use an autoregressive cache.
  """
    qkv_dim: int
    mlp_dim: int
    num_heads: int
    dtype: model_utils.Dtype = jnp.float32
    dropout_rate: float = 0.1
    attention_dropout_rate: float = 0.1
    normalizer: str = 'layer_norm'
    dec_self_attn_kernel_init_fn: model_utils.Initializer = initializers.xavier_uniform()  # pylint: disable=line-too-long
    dec_cross_attn_kernel_init_fn: model_utils.Initializer = initializers.xavier_uniform()  # pylint: disable=line-too-long
    decode: bool = False

    @nn.compact
    def __call__(self,
                 targets,
                 encoded,
                 decoder_mask=None,
                 encoder_decoder_mask=None,
                 train=True):
        """Applies EncoderDecoder1DBlock module.

    Args:
      targets: input data for decoder
      encoded: input data from encoder
      decoder_mask: decoder self-attention mask.
      encoder_decoder_mask: encoder-decoder attention mask.
      train: if it is training.

    Returns:
      output after transformer encoder-decoder block.
    """
        # Decoder block.
        assert targets.ndim == 3
        if self.normalizer in [
                'batch_norm', 'layer_norm', 'pre_layer_norm', 'none'
        ]:
            maybe_pre_normalize = model_utils.get_normalizer(
                self.normalizer, train)
            maybe_post_normalize = model_utils.get_normalizer('none', train)
        elif self.normalizer == 'post_layer_norm':
            maybe_pre_normalize = model_utils.get_normalizer('none', train)
            maybe_post_normalize = model_utils.get_normalizer(
                self.normalizer, train)
        else:
            raise ValueError('Unsupported normalizer: {}'.format(
                self.normalizer))

        x = maybe_pre_normalize()(targets)
        x = nn.SelfAttention(num_heads=self.num_heads,
                             dtype=self.dtype,
                             qkv_features=self.qkv_dim,
                             kernel_init=self.dec_self_attn_kernel_init_fn,
                             bias_init=nn.initializers.normal(stddev=1e-6),
                             use_bias=False,
                             broadcast_dropout=False,
                             dropout_rate=self.attention_dropout_rate,
                             decode=self.decode,
                             name='DecoderSelfAttention')(
                                 x, decoder_mask, deterministic=not train)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)
        x = x + targets

        x = maybe_post_normalize()(x)
        # Encoder-Decoder block.
        y = maybe_pre_normalize()(x)
        y = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            dtype=self.dtype,
            qkv_features=self.qkv_dim,
            kernel_init=self.dec_cross_attn_kernel_init_fn,
            bias_init=nn.initializers.normal(stddev=1e-6),
            use_bias=False,
            broadcast_dropout=False,
            dropout_rate=self.attention_dropout_rate)(y,
                                                      encoded,
                                                      encoder_decoder_mask,
                                                      deterministic=not train)

        y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=not train)
        y = y + x

        y = maybe_post_normalize()(y)
        # MLP block.
        z = maybe_pre_normalize()(y)
        z = MlpBlock(mlp_dim=self.mlp_dim,
                     dtype=self.dtype,
                     dropout_rate=self.dropout_rate,
                     name='MLPBlock')(z, train=train)

        res = y + z
        return maybe_post_normalize()(res)
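# A minimal sketch (not from the source) contrasting the pre- vs post-layer-norm
# residual wiring that the `normalizer` option above chooses between; nn.Dense
# stands in for the attention or MLP sub-layer, and the class name is made up.
import jax
import jax.numpy as jnp
from flax import linen as nn

class ResidualSublayer(nn.Module):
    pre_norm: bool = True   # True ~ 'pre_layer_norm', False ~ 'post_layer_norm'

    @nn.compact
    def __call__(self, x):
        sublayer = nn.Dense(x.shape[-1])             # placeholder transform
        if self.pre_norm:
            return x + sublayer(nn.LayerNorm()(x))   # normalize -> transform -> add
        return nn.LayerNorm()(x + sublayer(x))       # transform -> add -> normalize

x = jnp.ones((1, 4, 8))
params = ResidualSublayer().init(jax.random.PRNGKey(0), x)
print(ResidualSublayer().apply(params, x).shape)     # (1, 4, 8)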