Example 1
    def apply(self,
              inputs,
              vocab_size,
              emb_dim=512,
              num_heads=8,
              dtype=jnp.float32,
              num_layers=6,
              qkv_dim=512,
              mlp_dim=2048,
              max_len=2048,
              train=False,
              shift=True,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              cache=None):
        """Applies Transformer model on the inputs.

    Args:
      inputs: input data
      vocab_size: size of the vocabulary
      emb_dim: dimension of embedding
      num_heads: number of heads
      dtype: the dtype of the computation (default: float32)
      num_layers: number of layers
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      max_len: maximum length.
      train: bool: if model is training.
      shift: bool: if we right-shift input - this is only disabled for
        fast, looped single-token autoregressive decoding.
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      cache: flax autoregressive cache for fast decoding.

    Returns:
      output of a transformer decoder.
    """
        assert inputs.ndim == 2  # (batch, len)
        padding_mask = jnp.where(inputs > 0, 1, 0).astype(jnp.float32)[..., None]
        x = inputs
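        # Right-shift so the prediction at position i depends only on tokens < i
        # (skipped during cached single-token decoding).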
        if shift:
            x = common_layers.shift_right(x)
        x = x.astype('int32')
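        # Embed token ids and add fixed sinusoidal position embeddings.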
        x = common_layers.Embed(x,
                                num_embeddings=vocab_size,
                                features=emb_dim,
                                name='embed')
        x = common_layers.AddPositionEmbs(
            x,
            max_len=max_len,
            posemb_init=common_layers.sinusoidal_init(max_len=max_len),
            cache=cache)
        x = nn.dropout(x, rate=dropout_rate, deterministic=not train)
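        # Stack of BigBird decoder blocks with causal self-attention.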
        for _ in range(num_layers):
            x = BigBirdBlock(
                x,
                qkv_dim=qkv_dim,
                mlp_dim=mlp_dim,
                num_heads=num_heads,
                causal_mask=True,
                padding_mask=padding_mask,
                dropout_rate=dropout_rate,
                attention_dropout_rate=attention_dropout_rate,
                deterministic=not train,
                cache=cache,
            )
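        # Final layer norm and linear projection to vocabulary logits.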
        x = nn.LayerNorm(x)
        logits = nn.Dense(x,
                          vocab_size,
                          kernel_init=nn.initializers.xavier_uniform(),
                          bias_init=nn.initializers.normal(stddev=1e-6))
        return logits
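For context, here is a minimal usage sketch for this decoder. It assumes the pre-Linen Flax API (flax.nn, later moved to flax.deprecated.nn and removed in current Flax) that this apply-style module belongs to, and a hypothetical wrapper class name BigBirdLM; the real class name, hyperparameters, and training loop are not shown in the example above.

    import jax
    import jax.numpy as jnp
    from flax import nn  # pre-Linen Flax; on newer releases: from flax.deprecated import nn

    # `BigBirdLM` is a hypothetical name for the module that defines the
    # `apply` method above; substitute the class actually used in your codebase.
    model_def = BigBirdLM.partial(vocab_size=256, emb_dim=128, num_heads=4,
                                  num_layers=2, qkv_dim=128, mlp_dim=512,
                                  max_len=2048)
    rng = jax.random.PRNGKey(0)
    with nn.stochastic(rng):
        # Initialize parameters from a (batch, length) input specification.
        _, params = model_def.init_by_shape(rng, [((1, 2048), jnp.float32)])
        model = nn.Model(model_def, params)

    tokens = jnp.ones((1, 2048), jnp.int32)  # dummy token ids, 0 = padding
    logits = model(tokens, train=False)      # -> (1, 2048, vocab_size)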
Example 2
    def apply(self,
              inputs,
              vocab_size,
              inputs_positions=None,
              inputs_segmentation=None,
              shared_embedding=None,
              use_bfloat16=False,
              emb_dim=512,
              num_heads=8,
              dtype=jnp.float32,
              num_layers=6,
              qkv_dim=512,
              mlp_dim=2048,
              max_len=512,
              train=True,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              learn_pos_emb=False,
              classifier=False,
              classifier_pool='CLS',
              num_classes=10,
              block_size=_DEFAULT_BLOCK_SIZE):
        """Applies BigBird transformer model on the inputs.

    Args:
      inputs: input data
      vocab_size: size of the vocabulary
      inputs_positions: input subsequence positions for packed examples.
      inputs_segmentation: input segmentation info for packed examples.
      shared_embedding: a shared embedding layer to use.
      use_bfloat16: bool: whether use bfloat16.
      emb_dim: dimension of embedding
      num_heads: number of heads
      dtype: the dtype of the computation (default: float32)
      num_layers: number of layers
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      max_len: maximum length.
      train: if it is training,
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      learn_pos_emb: boolean, if learn the positional embedding or use the
        sinusoidal positional embedding.
      classifier: boolean, for classification mode (output N-class logits)
      classifier_pool: str, supports "MEAN", "MAX" pooling.
      num_classes: int, number of classification classes.
      block_size: Size of attention blocks.

    Returns:
      output of a transformer encoder or logits if classifier_mode is true.
    """
        assert inputs.ndim == 2  # (batch, len)

        # Padding Masks
        src_padding_mask = (inputs > 0)[..., None]

        # Input Embedding
        if shared_embedding is None:
            input_embed = nn.Embed.partial(
                num_embeddings=vocab_size,
                features=emb_dim,
                embedding_init=nn.initializers.normal(stddev=1.0))
        else:
            input_embed = shared_embedding
        x = inputs.astype('int32')
        x = input_embed(x)

        if classifier and classifier_pool == 'CLS':
            cls = self.param('cls', (1, 1, emb_dim), nn.initializers.zeros)
            cls = jnp.tile(cls, [x.shape[0], 1, 1])
            x = jnp.concatenate([cls, x], axis=1)
            max_len += 1
            src_padding_mask = jnp.concatenate(
                [src_padding_mask[:, :1], src_padding_mask], axis=1)

        pe_init = nn.initializers.normal(
            stddev=0.02) if learn_pos_emb else None
        x = common_layers.AddPositionEmbs(x,
                                          inputs_positions=inputs_positions,
                                          posemb_init=pe_init,
                                          max_len=max_len,
                                          name='posembed_input')
        x = nn.dropout(x, rate=dropout_rate, deterministic=not train)

        if use_bfloat16:
            x = x.astype(jnp.bfloat16)
            dtype = jnp.bfloat16
        else:
            dtype = jnp.float32

        # Input Encoder
        for lyr in range(num_layers):
            x = BigBirdBlock(x,
                             qkv_dim=qkv_dim,
                             mlp_dim=mlp_dim,
                             num_heads=num_heads,
                             dtype=dtype,
                             padding_mask=src_padding_mask,
                             inputs_segmentation=inputs_segmentation,
                             dropout_rate=dropout_rate,
                             attention_dropout_rate=attention_dropout_rate,
                             deterministic=not train,
                             block_size=block_size,
                             connectivity_seed=lyr,
                             name=f'encoderblock_{lyr}')
        encoded = nn.LayerNorm(x, dtype=dtype, name='encoder_norm')

        if classifier:
            encoded = common_layers.classifier_head(
                encoded, num_classes, mlp_dim, pooling_mode=classifier_pool)
        return encoded
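A similar hedged sketch for classifier mode, again assuming the pre-Linen Flax API and a hypothetical class name BigBirdEncoder for the module above. With classifier=True and classifier_pool='CLS', the model prepends a learned CLS token and returns per-example class logits.

    import jax
    import jax.numpy as jnp
    from flax import nn  # pre-Linen Flax API

    # `BigBirdEncoder` is a hypothetical name for the module defining `apply` above.
    encoder_def = BigBirdEncoder.partial(vocab_size=256, emb_dim=128,
                                         num_heads=4, num_layers=2,
                                         qkv_dim=128, mlp_dim=512,
                                         max_len=1024, train=False,
                                         classifier=True,
                                         classifier_pool='CLS',
                                         num_classes=10)
    rng = jax.random.PRNGKey(0)
    with nn.stochastic(rng):
        _, params = encoder_def.init_by_shape(rng, [((1, 1024), jnp.float32)])
        model = nn.Model(encoder_def, params)

    tokens = jnp.ones((2, 1024), jnp.int32)  # padded token ids, 0 = padding
    class_logits = model(tokens)             # -> (2, num_classes)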
Example 3
    def apply(self,
              inputs,
              vocab_size,
              inputs_positions=None,
              inputs_segmentation=None,
              shared_embedding=None,
              use_bfloat16=False,
              emb_dim=512,
              num_heads=8,
              num_layers=6,
              qkv_dim=512,
              mlp_dim=2048,
              max_len=512,
              train=True,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              block_size=50,
              learn_pos_emb=False,
              classifier=False,
              classifier_pool='MEAN',
              num_classes=10):
        """Applies Local Transformer model on the inputs.

    Args:
      inputs: input data
      vocab_size: size of the vocabulary
      inputs_positions: input subsequence positions for packed examples.
      inputs_segmentation: input segmentation info for packed examples.
      shared_embedding: a shared embedding layer to use.
      use_bfloat16: bool: whether use bfloat16.
      emb_dim: dimension of embedding
      num_heads: number of heads
      num_layers: number of layers
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      max_len: maximum length.
      train: if it is training,
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      block_size: int, block size.
      learn_pos_emb: boolean, if learn the positional embedding or use the
        sinusoidal positional embedding.
      classifier: boolean, for classification mode (output N-class logits)
      classifier_pool: str, supports "MEAN", "MAX" pooling.
      num_classes: int, number of classification classes.

    Returns:
      output of a transformer encoder.
    """
        assert inputs.ndim == 2  # (batch, len)

        # Padding Masks
        src_padding_mask = (inputs > 0)[..., None]

        # Input Embedding
        if shared_embedding is None:
            input_embed = nn.Embed.partial(
                num_embeddings=vocab_size,
                features=emb_dim,
                embedding_init=nn.initializers.normal(stddev=1.0))
        else:
            input_embed = shared_embedding
        x = inputs.astype('int32')
        x = input_embed(x)
        pe_init = nn.initializers.normal(
            stddev=0.02) if learn_pos_emb else None
        x = common_layers.AddPositionEmbs(x,
                                          inputs_positions=inputs_positions,
                                          posemb_init=pe_init,
                                          max_len=max_len,
                                          name='posembed_input')
        x = nn.dropout(x, rate=dropout_rate, deterministic=not train)

        if use_bfloat16:
            x = x.astype(jnp.bfloat16)
            dtype = jnp.bfloat16
        else:
            dtype = jnp.float32

        # Input Encoder
        for lyr in range(num_layers):
            x = SinkhornTransformerBlock(
                x,
                qkv_dim=qkv_dim,
                mlp_dim=mlp_dim,
                num_heads=num_heads,
                dtype=dtype,
                padding_mask=src_padding_mask,
                inputs_segmentation=inputs_segmentation,
                dropout_rate=dropout_rate,
                attention_dropout_rate=attention_dropout_rate,
                deterministic=not train,
                name=f'encoderblock_{lyr}',
                block_size=block_size)
        encoded = nn.LayerNorm(x, dtype=dtype, name='encoder_norm')

        if classifier:
            if classifier_pool == 'MEAN':
                encoded = jnp.mean(encoded, axis=1)
                encoded = nn.Dense(encoded, num_classes, name='logits')
            else:
                # TODO(yitay): Add other pooling methods.
                raise ValueError('Pooling method not supported yet.')
        return encoded
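The MEAN-pooling classifier branch at the end is just an average over the sequence axis followed by a dense projection. The shape flow can be sketched with plain jax.numpy and stand-in (zero-valued, hypothetical) parameters in place of the nn.Dense layer:

    import jax.numpy as jnp

    batch, seq_len, emb_dim, num_classes = 2, 8, 16, 10
    encoded = jnp.ones((batch, seq_len, emb_dim))  # encoder output
    kernel = jnp.zeros((emb_dim, num_classes))     # stand-in for the Dense kernel
    bias = jnp.zeros((num_classes,))               # stand-in for the Dense bias

    pooled = jnp.mean(encoded, axis=1)             # MEAN pooling -> (batch, emb_dim)
    logits = pooled @ kernel + bias                # 'logits' projection -> (batch, num_classes)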