class DeepViTConfig:
  num_classes: int = 1000
  depth: int = 32
  mlp_dim: int = 1224
  token_dim: int = 64
  emb_dim: int = 408
  num_heads: int = 12
  dim_head: int = 32
  shared_theta: bool = True
  activation_fn: ModuleDef = nn.gelu
  dtype: jnp.dtype = jnp.float32
  precision: Any = jax.lax.Precision.DEFAULT
  kernel_init: Callable = initializers.xavier_uniform()
  bias_init: Callable = initializers.normal(stddev=1e-6)
  posemb_init: Callable = initializers.normal(stddev=0.02)
class TNTConfig:
  """Global hyperparameters used to minimize obnoxious kwarg plumbing."""
  num_classes: int = 1000
  depth: int = 12
  image_size: int = 224
  patch_size: int = 16
  transformed_patch_size: int = 4
  inner_dim: int = 40
  inner_heads: int = 4
  inner_dim_head: int = 64
  inner_r: int = 4
  outer_dim: int = 640
  outer_heads: int = 10
  outer_dim_head: int = 64
  outer_r: int = 4
  dtype: Any = jnp.float32
  kernel_init: Callable = initializers.xavier_uniform()
  bias_init: Callable = initializers.normal(stddev=1e-6)
  posemb_init: Callable = initializers.normal(stddev=0.02)
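# Hypothetical usage sketch for the configs above (it assumes they are declared
# as dataclasses, e.g. via a @dataclasses.dataclass decorator not shown in this
# excerpt; the overridden values are illustrative only):
#
#   config = TNTConfig(num_classes=10, depth=6)
#   num_patches = (config.image_size // config.patch_size) ** 2  # 14 * 14 = 196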
class Transformer(nn.Module):
  """Transformer model for sequence-to-sequence translation.

  vocab_size: size of the input vocabulary.
  output_vocab_size: size of the output vocabulary. If None, the output
    vocabulary size is assumed to be the same as vocab_size.
  share_embeddings: bool: share embedding layer for inputs and targets.
  logits_via_embedding: bool: whether final logit transform shares embedding
    weights.
  use_bfloat16: bool: whether to use bfloat16.
  emb_dim: dimension of embedding.
  num_heads: number of heads.
  enc_num_layers: number of encoder layers.
  dec_num_layers: number of decoder layers.
  qkv_dim: dimension of the query/key/value.
  mlp_dim: dimension of the mlp on top of attention block.
  max_len: maximum length.
  dropout_rate: dropout rate.
  attention_dropout_rate: dropout rate for attention weights.
  normalizer: One of 'batch_norm', 'layer_norm', 'none'.
  enc_self_attn_kernel_init_fn: initializer for encoder's self attention
    matrices.
  dec_self_attn_kernel_init_fn: initializer for decoder's self attention
    matrices.
  dec_cross_attn_kernel_init_fn: initializer for decoder's cross attention
    matrices.
  should_decode: whether to use an autoregressive cache.
  """
  vocab_size: Optional[int] = None
  output_vocab_size: Optional[int] = None
  share_embeddings: bool = False
  logits_via_embedding: bool = False
  use_bfloat16: bool = False
  emb_dim: int = 512
  num_heads: int = 8
  enc_num_layers: int = 6
  dec_num_layers: int = 6
  qkv_dim: int = 512
  mlp_dim: int = 2048
  max_len: int = 2048
  dropout_rate: float = 0.3
  attention_dropout_rate: float = 0.3
  normalizer: str = 'layer_norm'
  enc_self_attn_kernel_init_fn: model_utils.Initializer = initializers.xavier_uniform()  # pylint: disable=line-too-long
  dec_self_attn_kernel_init_fn: model_utils.Initializer = initializers.xavier_uniform()  # pylint: disable=line-too-long
  dec_cross_attn_kernel_init_fn: model_utils.Initializer = initializers.xavier_uniform()  # pylint: disable=line-too-long
  should_decode: bool = False

  def setup(self):
    if self.share_embeddings:
      if self.output_vocab_size is not None:
        assert self.output_vocab_size == self.vocab_size, (
            "can't share embedding with different vocab sizes.")
      self.shared_embedding = nn.Embed(
          num_embeddings=self.vocab_size,
          features=self.emb_dim,
          embedding_init=nn.initializers.normal(stddev=1.0),
          name='VocabEmbeddings')
    else:
      self.shared_embedding = None

    self.encoder = Encoder(
        vocab_size=self.vocab_size,
        shared_embedding=self.shared_embedding,
        use_bfloat16=self.use_bfloat16,
        emb_dim=self.emb_dim,
        num_heads=self.num_heads,
        enc_num_layers=self.enc_num_layers,
        qkv_dim=self.qkv_dim,
        mlp_dim=self.mlp_dim,
        max_len=self.max_len,
        dropout_rate=self.dropout_rate,
        attention_dropout_rate=self.attention_dropout_rate,
        normalizer=self.normalizer,
        enc_self_attn_kernel_init_fn=self.enc_self_attn_kernel_init_fn,
        name='encoder')

    self.decoder = Decoder(
        output_vocab_size=self.output_vocab_size,
        shared_embedding=self.shared_embedding,
        logits_via_embedding=self.logits_via_embedding,
        use_bfloat16=self.use_bfloat16,
        emb_dim=self.emb_dim,
        num_heads=self.num_heads,
        dec_num_layers=self.dec_num_layers,
        qkv_dim=self.qkv_dim,
        mlp_dim=self.mlp_dim,
        max_len=self.max_len,
        dropout_rate=self.dropout_rate,
        attention_dropout_rate=self.attention_dropout_rate,
        normalizer=self.normalizer,
        dec_self_attn_kernel_init_fn=self.dec_self_attn_kernel_init_fn,
        dec_cross_attn_kernel_init_fn=self.dec_cross_attn_kernel_init_fn,
        decode=self.should_decode,
        name='decoder')

  @nn.compact
  def __call__(self,
               inputs,
               targets,
               inputs_positions=None,
               targets_positions=None,
               inputs_segmentation=None,
               targets_segmentation=None,
               train=False):
    """Applies the Transformer model to the inputs.

    Args:
      inputs: input data.
      targets: target data.
      inputs_positions: input subsequence positions for packed examples.
      targets_positions: target subsequence positions for packed examples.
      inputs_segmentation: input segmentation info for packed examples.
      targets_segmentation: target segmentation info for packed examples.
      train: whether it is training.

    Returns:
      Output: <float>[batch_size, target_sequence_length, output_vocab_size]
    """
    encoded = self.encode(inputs,
                          inputs_positions=inputs_positions,
                          inputs_segmentation=inputs_segmentation,
                          train=train)
    logits = self.decode(
        encoded,
        inputs,  # only used for masks
        targets,
        targets_positions=targets_positions,
        inputs_segmentation=inputs_segmentation,
        targets_segmentation=targets_segmentation,
        train=train)
    return logits.astype(jnp.float32) if self.use_bfloat16 else logits

  # The following two methods allow us to run the trained Transformer in
  # two parts during fast decoding. First, we call the encoder branch to
  # encode the inputs, then we call the decoder branch while using a
  # cache object for iteratively storing keys and values during the decoding
  # process.
  def encode(self,
             inputs,
             inputs_positions=None,
             inputs_segmentation=None,
             train=False):
    # Make padding attention mask.
    encoder_mask = nn.make_attention_mask(
        inputs > 0, inputs > 0, dtype=inputs.dtype)
    # Add segmentation block-diagonal attention mask if using segmented data.
    if inputs_segmentation is not None:
      encoder_mask = nn.combine_masks(
          encoder_mask,
          nn.make_attention_mask(inputs_segmentation,
                                 inputs_segmentation,
                                 jnp.equal,
                                 dtype=inputs_segmentation.dtype))
    encoded = self.encoder(inputs,
                           inputs_positions=inputs_positions,
                           encoder_mask=encoder_mask,
                           train=train)
    return encoded

  def decode(self,
             encoded,
             inputs,
             targets,
             targets_positions=None,
             inputs_segmentation=None,
             targets_segmentation=None,
             train=False):
    # Make padding attention masks.
    dtype = jnp.bfloat16 if self.use_bfloat16 else jnp.float32
    if self.should_decode:
      # For fast autoregressive decoding, only a special encoder-decoder mask
      # is used.
      decoder_mask = None
      encoder_decoder_mask = nn.make_attention_mask(
          jnp.ones_like(targets) > 0, inputs > 0, dtype=dtype)
    else:
      decoder_mask = nn.combine_masks(
          nn.make_attention_mask(targets > 0, targets > 0, dtype=dtype),
          nn.make_causal_mask(targets, dtype=dtype))
      encoder_decoder_mask = nn.make_attention_mask(
          targets > 0, inputs > 0, dtype=dtype)

    # Add segmentation block-diagonal attention masks if using segmented data.
    if inputs_segmentation is not None:
      decoder_mask = nn.combine_masks(
          decoder_mask,
          nn.make_attention_mask(targets_segmentation,
                                 targets_segmentation,
                                 jnp.equal,
                                 dtype=dtype))
      encoder_decoder_mask = nn.combine_masks(
          encoder_decoder_mask,
          nn.make_attention_mask(targets_segmentation,
                                 inputs_segmentation,
                                 jnp.equal,
                                 dtype=dtype))
    logits = self.decoder(encoded,
                          targets,
                          targets_positions=targets_positions,
                          decoder_mask=decoder_mask,
                          encoder_decoder_mask=encoder_decoder_mask,
                          train=train)
    return logits
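# Hedged usage sketch for the Transformer above (vocab sizes, shapes, and token
# ids are illustrative, not taken from the codebase):
#
#   model = Transformer(vocab_size=1000, output_vocab_size=1000)
#   inputs = jnp.ones((2, 8), jnp.int32)   # (batch, source_len) token ids
#   targets = jnp.ones((2, 8), jnp.int32)  # (batch, target_len) token ids
#   variables = model.init(jax.random.PRNGKey(0), inputs, targets)
#   logits = model.apply(variables, inputs, targets)
#   # logits: (2, 8, 1000) float scores over the output vocabulary.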
class Decoder(nn.Module):
  """Transformer model decoder for sequence-to-sequence translation.

  output_vocab_size: size of the vocabulary.
  shared_embedding: a shared embedding layer to use.
  logits_via_embedding: bool: whether final logit transform shares embedding
    weights.
  use_bfloat16: bool: whether to use bfloat16.
  emb_dim: dimension of embedding.
  num_heads: number of heads.
  dec_num_layers: number of layers.
  qkv_dim: dimension of the query/key/value.
  mlp_dim: dimension of the mlp on top of attention block.
  max_len: maximum length.
  decode: whether to use an autoregressive cache.
  dropout_rate: dropout rate.
  normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm',
    'pre_layer_norm', 'none'.
  attention_dropout_rate: dropout rate for attention weights.
  dec_self_attn_kernel_init_fn: initializer for decoder's self attention
    matrices.
  dec_cross_attn_kernel_init_fn: initializer for decoder's cross attention
    matrices.
  """
  output_vocab_size: int
  shared_embedding: Any = None
  logits_via_embedding: bool = False
  use_bfloat16: bool = False
  emb_dim: int = 512
  num_heads: int = 8
  dec_num_layers: int = 6
  qkv_dim: int = 512
  mlp_dim: int = 2048
  max_len: int = 512
  decode: bool = False
  dropout_rate: float = 0.1
  normalizer: str = 'layer_norm'
  attention_dropout_rate: float = 0.1
  dec_self_attn_kernel_init_fn: model_utils.Initializer = initializers.xavier_uniform()  # pylint: disable=line-too-long
  dec_cross_attn_kernel_init_fn: model_utils.Initializer = initializers.xavier_uniform()  # pylint: disable=line-too-long

  @nn.compact
  def __call__(self,
               encoded,
               targets,
               targets_positions=None,
               decoder_mask=None,
               encoder_decoder_mask=None,
               train=True):
    """Applies the Transformer decoder to the inputs.

    Args:
      encoded: encoded input data from encoder.
      targets: target inputs.
      targets_positions: input subsequence positions for packed examples.
      decoder_mask: decoder self-attention mask.
      encoder_decoder_mask: encoder-decoder attention mask.
      train: whether it is training.

    Returns:
      output of a transformer decoder.
    """
    assert encoded.ndim == 3  # (batch, len, depth)
    assert targets.ndim == 2  # (batch, len)
    dtype = _get_dtype(self.use_bfloat16)

    # Target embedding.
    if self.shared_embedding is None:
      output_embed = nn.Embed(
          num_embeddings=self.output_vocab_size,
          features=self.emb_dim,
          embedding_init=nn.initializers.normal(stddev=1.0),
          name='output_vocab_embeddings')
    else:
      output_embed = self.shared_embedding

    y = targets.astype('int32')
    if not self.decode:
      y = shift_right(y)
    y = output_embed(y)
    y = AddPositionEmbs(max_len=self.max_len,
                        decode=self.decode,
                        name='posembed_output')(
                            y, inputs_positions=targets_positions, train=train)
    y = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(y)

    if self.use_bfloat16:
      y = y.astype(jnp.bfloat16)

    # Target-input decoder.
    for lyr in range(self.dec_num_layers):
      y = EncoderDecoder1DBlock(
          qkv_dim=self.qkv_dim,
          mlp_dim=self.mlp_dim,
          num_heads=self.num_heads,
          dtype=dtype,
          dropout_rate=self.dropout_rate,
          attention_dropout_rate=self.attention_dropout_rate,
          normalizer=self.normalizer,
          dec_self_attn_kernel_init_fn=self.dec_self_attn_kernel_init_fn,
          dec_cross_attn_kernel_init_fn=self.dec_cross_attn_kernel_init_fn,
          decode=self.decode,
          name=f'encoderdecoderblock_{lyr}')(
              y,
              encoded,
              decoder_mask=decoder_mask,
              encoder_decoder_mask=encoder_decoder_mask,
              train=train)

    if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']:
      maybe_normalize = model_utils.get_normalizer(self.normalizer, train)
      y = maybe_normalize()(y)

    # Decoded logits.
    if self.logits_via_embedding:
      # Use the transpose of the embedding matrix for the logit transform.
      logits = output_embed.attend(y.astype(jnp.float32))
      # Correctly normalize pre-softmax logits for this shared case.
      logits = logits / jnp.sqrt(y.shape[-1])
    else:
      logits = nn.Dense(self.output_vocab_size,
                        dtype=dtype,
                        kernel_init=nn.initializers.xavier_uniform(),
                        bias_init=nn.initializers.normal(stddev=1e-6),
                        name='logitdense')(y)
    return logits
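# Note on the shared-embedding logit path above: output_embed.attend()
# multiplies the decoder activations by the transposed embedding table, so the
# raw logits grow with emb_dim; dividing by sqrt(emb_dim) rescales them,
# mirroring the scaling used in dot-product attention. A toy illustration
# (hypothetical values, not from the codebase):
#
#   y = jnp.full((1, 1, 512), 0.1)     # decoder output
#   emb = jnp.full((1000, 512), 0.1)   # embedding table
#   raw = y @ emb.T                    # every logit = 512 * 0.01 = 5.12
#   scaled = raw / jnp.sqrt(512.0)     # ~= 0.226, a tamer pre-softmax scale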
class Encoder(nn.Module):
  """Transformer model encoder for sequence-to-sequence translation.

  vocab_size: size of the vocabulary.
  shared_embedding: a shared embedding layer to use.
  use_bfloat16: bool: whether to use bfloat16.
  emb_dim: dimension of embedding.
  num_heads: number of heads.
  enc_num_layers: number of layers.
  qkv_dim: dimension of the query/key/value.
  mlp_dim: dimension of the mlp on top of attention block.
  max_len: maximum length.
  dropout_rate: dropout rate.
  normalizer: One of 'batch_norm', 'layer_norm', 'none'.
  attention_dropout_rate: dropout rate for attention weights.
  enc_self_attn_kernel_init_fn: initializer for encoder's self attention
    matrices.
  """
  vocab_size: int
  shared_embedding: Any = None
  use_bfloat16: bool = False
  emb_dim: int = 512
  num_heads: int = 8
  enc_num_layers: int = 6
  qkv_dim: int = 512
  mlp_dim: int = 2048
  max_len: int = 512
  dropout_rate: float = 0.1
  normalizer: str = 'layer_norm'
  attention_dropout_rate: float = 0.1
  enc_self_attn_kernel_init_fn: model_utils.Initializer = initializers.xavier_uniform()  # pylint: disable=line-too-long

  @nn.compact
  def __call__(self,
               inputs,
               inputs_positions=None,
               encoder_mask=None,
               train=True):
    """Applies the Transformer encoder to the inputs.

    Args:
      inputs: input data.
      inputs_positions: input subsequence positions for packed examples.
      encoder_mask: encoder self-attention mask.
      train: if it is training.

    Returns:
      output of a transformer encoder.
    """
    assert inputs.ndim == 2  # (batch, len)

    # Input embedding.
    if self.shared_embedding is None:
      input_embed = nn.Embed(
          num_embeddings=self.vocab_size,
          features=self.emb_dim,
          embedding_init=nn.initializers.normal(stddev=1.0),
          name='input_vocab_embeddings')
    else:
      input_embed = self.shared_embedding

    x = inputs.astype('int32')
    x = input_embed(x)
    x = AddPositionEmbs(max_len=self.max_len,
                        decode=False,
                        name='posembed_input')(
                            x, inputs_positions=inputs_positions, train=train)
    x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)

    if self.use_bfloat16:
      x = x.astype(jnp.bfloat16)
      dtype = jnp.bfloat16
    else:
      dtype = jnp.float32

    # Input encoder.
    for lyr in range(self.enc_num_layers):
      x = Encoder1DBlock(
          qkv_dim=self.qkv_dim,
          mlp_dim=self.mlp_dim,
          num_heads=self.num_heads,
          dtype=dtype,
          dropout_rate=self.dropout_rate,
          attention_dropout_rate=self.attention_dropout_rate,
          normalizer=self.normalizer,
          enc_self_attn_kernel_init_fn=self.enc_self_attn_kernel_init_fn,
          name=f'encoderblock_{lyr}')(
              x, encoder_mask=encoder_mask, train=train)

    if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']:
      maybe_normalize = model_utils.get_normalizer(self.normalizer, train)
      x = maybe_normalize()(x)
    return x
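# Hedged sketch of running the encoder on its own (token ids and shapes are
# illustrative; 0 is treated as padding, matching the masking convention used
# in Transformer.encode above):
#
#   enc = Encoder(vocab_size=1000)
#   tokens = jnp.array([[5, 8, 3, 0, 0]], jnp.int32)
#   mask = nn.make_attention_mask(tokens > 0, tokens > 0)
#   variables = enc.init(jax.random.PRNGKey(0), tokens,
#                        encoder_mask=mask, train=False)
#   encoded = enc.apply(variables, tokens, encoder_mask=mask, train=False)
#   # encoded: (1, 5, 512) activations with the default emb_dim of 512.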
class EncoderDecoder1DBlock(nn.Module):
  """Transformer encoder-decoder layer.

  Attributes:
    qkv_dim: Dimension of the query/key/value.
    mlp_dim: Dimension of the mlp on top of attention block.
    num_heads: Number of heads.
    dtype: Dtype of the computation (default: float32).
    dropout_rate: <float> Dropout rate.
    attention_dropout_rate: <float> Dropout rate for attention weights.
    normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm',
      'pre_layer_norm', 'none'.
    dec_self_attn_kernel_init_fn: initializer for decoder's self attention
      matrices.
    dec_cross_attn_kernel_init_fn: initializer for decoder's cross attention
      matrices.
    decode: whether to use an autoregressive cache.
  """
  qkv_dim: int
  mlp_dim: int
  num_heads: int
  dtype: model_utils.Dtype = jnp.float32
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  normalizer: str = 'layer_norm'
  dec_self_attn_kernel_init_fn: model_utils.Initializer = initializers.xavier_uniform()  # pylint: disable=line-too-long
  dec_cross_attn_kernel_init_fn: model_utils.Initializer = initializers.xavier_uniform()  # pylint: disable=line-too-long
  decode: bool = False

  @nn.compact
  def __call__(self,
               targets,
               encoded,
               decoder_mask=None,
               encoder_decoder_mask=None,
               train=True):
    """Applies EncoderDecoder1DBlock module.

    Args:
      targets: input data for decoder.
      encoded: input data from encoder.
      decoder_mask: decoder self-attention mask.
      encoder_decoder_mask: encoder-decoder attention mask.
      train: if it is training.

    Returns:
      output after transformer encoder-decoder block.
    """
    # Decoder block.
    assert targets.ndim == 3
    if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm',
                           'none']:
      maybe_pre_normalize = model_utils.get_normalizer(self.normalizer, train)
      maybe_post_normalize = model_utils.get_normalizer('none', train)
    elif self.normalizer == 'post_layer_norm':
      maybe_pre_normalize = model_utils.get_normalizer('none', train)
      maybe_post_normalize = model_utils.get_normalizer(self.normalizer, train)
    else:
      raise ValueError('Unsupported normalizer: {}'.format(self.normalizer))

    x = maybe_pre_normalize()(targets)
    x = nn.SelfAttention(num_heads=self.num_heads,
                         dtype=self.dtype,
                         qkv_features=self.qkv_dim,
                         kernel_init=self.dec_self_attn_kernel_init_fn,
                         bias_init=nn.initializers.normal(stddev=1e-6),
                         use_bias=False,
                         broadcast_dropout=False,
                         dropout_rate=self.attention_dropout_rate,
                         decode=self.decode,
                         name='DecoderSelfAttention')(
                             x, decoder_mask, deterministic=not train)
    x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)
    x = x + targets
    x = maybe_post_normalize()(x)

    # Encoder-decoder block.
    y = maybe_pre_normalize()(x)
    y = nn.MultiHeadDotProductAttention(
        num_heads=self.num_heads,
        dtype=self.dtype,
        qkv_features=self.qkv_dim,
        kernel_init=self.dec_cross_attn_kernel_init_fn,
        bias_init=nn.initializers.normal(stddev=1e-6),
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=self.attention_dropout_rate)(
            y, encoded, encoder_decoder_mask, deterministic=not train)
    y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=not train)
    y = y + x
    y = maybe_post_normalize()(y)

    # MLP block.
    z = maybe_pre_normalize()(y)
    z = MlpBlock(mlp_dim=self.mlp_dim,
                 dtype=self.dtype,
                 dropout_rate=self.dropout_rate,
                 name='MLPBlock')(z, train=train)
    res = y + z
    return maybe_post_normalize()(res)
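# Hedged sketch of exercising a single block in isolation (shapes are
# illustrative; it assumes the MlpBlock and model_utils helpers imported by
# this module are available):
#
#   block = EncoderDecoder1DBlock(qkv_dim=512, mlp_dim=2048, num_heads=8)
#   targets = jnp.ones((1, 7, 512))   # decoder-side activations
#   encoded = jnp.ones((1, 9, 512))   # encoder outputs
#   variables = block.init(jax.random.PRNGKey(0), targets, encoded, train=False)
#   out = block.apply(variables, targets, encoded, train=False)  # (1, 7, 512)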