def decode(self, programs, encoded, encoded_padding_mask):
  """Applies decoder on programs and encoded specification."""
  cfg = self.config

  assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                              ' but it is: %d' % programs.ndim)
  assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                             ' but it is: %d' % encoded.ndim)

  # Collapse num_io dimension.
  flat_encoded = base_models.flatten_num_io_dim(encoded)
  flat_encoded_padding_mask = base_models.flatten_num_io_dim(
      encoded_padding_mask)

  preshift_programs = programs  # Save pre-shifted programs for padding mask.
  if cfg.shift:
    programs = base_models.shift_right(programs, cfg.bos_token)

  # Make attention masks.
  if cfg.decode:
    # For fast decode with caching, programs shape == [batch_size, 1] and
    # cfg.shift = False, cfg.decode = True.
    # TODO(jxihong): Fast decoding currently does not work with new attention.
    decoder_mask = None
    encoder_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(programs), flat_encoded_padding_mask, dtype=cfg.dtype)
  else:
    # BOS tokens attend to all previous BOS tokens.
    decoder_bos_mask = nn.combine_masks(
        nn.make_attention_mask(programs == cfg.bos_token,
                               programs == cfg.bos_token,
                               dtype=cfg.dtype),
        nn.make_causal_mask(programs, dtype=cfg.dtype))
    # Program tokens attend to all previous tokens in partial program.
    decoder_partial_mask = nn.combine_masks(
        make_partial_program_mask(programs,
                                  bos_token=cfg.bos_token,
                                  dtype=cfg.dtype),
        nn.make_causal_mask(programs, dtype=cfg.dtype))
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(preshift_programs > 0,
                               preshift_programs > 0,
                               dtype=cfg.dtype),
        jnp.logical_or(decoder_bos_mask, decoder_partial_mask))
    encoder_decoder_mask = nn.make_attention_mask(
        programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)

  return self.decoder(programs, flat_encoded, decoder_mask,
                      encoder_decoder_mask)

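# Illustrative sketch (not from the sources above): how the Flax mask helpers
# used throughout these decode functions compose. `nn.make_attention_mask`
# builds a [batch, 1, q_len, kv_len] mask from per-token validity,
# `nn.make_causal_mask` adds the autoregressive constraint, and
# `nn.combine_masks` takes their logical AND. The token ids below are made up.
import jax.numpy as jnp
from flax import linen as nn

programs = jnp.array([[1, 5, 7, 0, 0]])            # toy batch; 0 is padding
padding_mask = nn.make_attention_mask(programs > 0, programs > 0)
causal_mask = nn.make_causal_mask(programs)
decoder_mask = nn.combine_masks(padding_mask, causal_mask)
print(decoder_mask.shape)                          # (1, 1, 5, 5)
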
def setup(self):
    self.embed_dim = self.config.hidden_size
    self.num_heads = self.config.num_attention_heads
    self.head_dim = self.embed_dim // self.num_heads
    assert (
        self.head_dim * self.num_heads == self.embed_dim
    ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
    self.scale = self.head_dim**-0.5
    self.dropout = self.config.attention_dropout

    self.k_proj = nn.Dense(
        self.embed_dim,
        dtype=self.dtype,
        kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype))
    self.v_proj = nn.Dense(
        self.embed_dim,
        dtype=self.dtype,
        kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype))
    self.q_proj = nn.Dense(
        self.embed_dim,
        dtype=self.dtype,
        kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype))
    self.out_proj = nn.Dense(
        self.embed_dim,
        dtype=self.dtype,
        kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype))

    self.causal = isinstance(self.config, CLIPTextConfig)
    if self.causal:
        self.causal_mask = make_causal_mask(
            jnp.ones((1, self.config.max_position_embeddings), dtype="i4"))

def setup(self):
    config = self.config
    self.embed_dim = config.hidden_size
    self.num_heads = config.num_attention_heads
    self.head_dim = self.embed_dim // self.num_heads

    self.rotary_dim = config.rotary_dim

    dense = partial(
        nn.Dense,
        self.embed_dim,
        use_bias=False,
        dtype=self.dtype,
        kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
    )

    self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
    self.out_proj = dense()

    self.resid_dropout = nn.Dropout(rate=config.resid_pdrop)

    self.causal_mask = make_causal_mask(
        jnp.ones((1, config.max_position_embeddings), dtype="bool"),
        dtype="bool")

    pos_embd_dim = self.rotary_dim or self.embed_dim
    self.embed_positions = create_sinusoidal_positions(
        config.max_position_embeddings, pos_embd_dim)

def decode(self, programs, encoded, encoded_padding_mask):
  """Applies decoder on programs and encoded specification."""
  cfg = self.config

  assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                              ' but it is: %d' % programs.ndim)
  assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                             ' but it is: %d' % encoded.ndim)

  # Collapse num_io dimension.
  flat_encoded = flatten_num_io_dim(encoded)
  flat_encoded_padding_mask = flatten_num_io_dim(encoded_padding_mask)

  # Make attention masks.
  if cfg.decode:
    # For fast decode with caching, programs shape == [batch_size, 1] and
    # cfg.shift = False, cfg.decode = True.
    decoder_mask = None
    encoder_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(programs), flat_encoded_padding_mask, dtype=cfg.dtype)
  else:
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(programs > 0, programs > 0, dtype=cfg.dtype),
        nn.make_causal_mask(programs, dtype=cfg.dtype))
    encoder_decoder_mask = nn.make_attention_mask(
        programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)

  return self.decoder(
      programs, flat_encoded, decoder_mask, encoder_decoder_mask)

def decode(self, programs, encoded, encoded_padding_mask):
  """Applies decoder on programs and encoded specification."""
  cfg = self.config

  # Allow for decoding without num_partial dimension for beam search.
  # programs shape == [batch_size, (num_partial), length]
  assert programs.ndim in [2, 3], ('Number of program dimensions should be'
                                   ' 2 or 3, but it is: %d' % programs.ndim)
  assert encoded.ndim == programs.ndim + 2

  # Collapse num_io dimension.
  num_io_axis = 1 if programs.ndim == 2 else 2
  flat_encoded = base_models.flatten_num_io_dim(encoded, axis=num_io_axis)
  flat_encoded_padding_mask = base_models.flatten_num_io_dim(
      encoded_padding_mask, axis=num_io_axis)

  # Make attention masks.
  if cfg.decode:
    # For fast decode with caching, programs shape == [batch_size, 1] and
    # cfg.shift = False, cfg.decode = True.
    decoder_mask = None
    encoder_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(programs), flat_encoded_padding_mask, dtype=cfg.dtype)
  else:
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(programs > 0, programs > 0, dtype=cfg.dtype),
        nn.make_causal_mask(programs, dtype=cfg.dtype))
    encoder_decoder_mask = nn.make_attention_mask(
        programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)

  return self.decoder(
      programs, flat_encoded, decoder_mask, encoder_decoder_mask)

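# Illustrative sketch (an assumption, not the library's implementation):
# `flatten_num_io_dim` is taken here to merge the num_io axis into the length
# axis so the decoder can cross-attend over all I/O examples at once, e.g.
# [batch, num_io, len, dim] -> [batch, num_io * len, dim].
import jax.numpy as jnp

def flatten_num_io_dim_sketch(x, axis=1):
  # Merge the num_io axis with the following (length) axis.
  new_shape = (x.shape[:axis] + (x.shape[axis] * x.shape[axis + 1],)
               + x.shape[axis + 2:])
  return x.reshape(new_shape)

encoded = jnp.zeros((8, 3, 16, 128))               # [batch, num_io, len, dim]
print(flatten_num_io_dim_sketch(encoded).shape)    # (8, 48, 128)
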
def setup(self):
    config = self.config
    self.embed_dim = config.hidden_size
    self.num_heads = config.num_attention_heads
    self.head_dim = self.embed_dim // self.num_heads
    if self.head_dim * self.num_heads != self.embed_dim:
        raise ValueError(
            f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and "
            f"`num_heads`: {self.num_heads}).")

    self.attn_dropout = nn.Dropout(config.attention_dropout)
    self.resid_dropout = nn.Dropout(config.resid_dropout)

    dense = partial(
        nn.Dense,
        self.embed_dim,
        dtype=self.dtype,
        kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
    )

    self.q_proj, self.k_proj, self.v_proj = (
        dense(use_bias=False),
        dense(use_bias=False),
        dense(use_bias=False),
    )
    self.out_proj = dense()

    self.causal_mask = make_causal_mask(
        jnp.ones((1, config.max_position_embeddings), dtype="bool"),
        dtype="bool")
    if self.attention_type == "local":
        self.causal_mask = self.causal_mask ^ jnp.tril(
            self.causal_mask, -config.window_size)

def setup(self) -> None:
    self.head_dim = self.embed_dim // self.num_heads
    if self.head_dim * self.num_heads != self.embed_dim:
        raise ValueError(
            f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} "
            f"and `num_heads`: {self.num_heads}).")

    dense = partial(
        nn.Dense,
        self.embed_dim,
        use_bias=self.bias,
        dtype=self.dtype,
        kernel_init=jax.nn.initializers.normal(self.config.init_std),
    )

    self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
    self.out_proj = dense()

    self.dropout_layer = nn.Dropout(rate=self.dropout)

    if self.causal:
        self.causal_mask = make_causal_mask(
            jnp.ones((1, self.config.max_position_embeddings), dtype="bool"),
            dtype="bool")

def decode(self,
           encoded,
           inputs,  # only needed for masks
           targets,
           targets_positions=None,
           inputs_segmentation=None,
           targets_segmentation=None):
  """Applies Transformer decoder-branch on encoded-input and target.

  Args:
    encoded: encoded input data from encoder.
    inputs: input data (only needed for masking).
    targets: target data.
    targets_positions: target subsequence positions for packed examples.
    inputs_segmentation: input segmentation info for packed examples.
    targets_segmentation: target segmentation info for packed examples.

  Returns:
    logits array from transformer decoder.
  """
  config = self.config

  # Make padding attention masks.
  if config.decode:
    # For fast autoregressive decoding only a special encoder-decoder mask is
    # used.
    decoder_mask = None
    encoder_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(targets) > 0, inputs > 0, dtype=config.dtype)
  else:
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(targets > 0, targets > 0, dtype=config.dtype),
        nn.make_causal_mask(targets, dtype=config.dtype))
    encoder_decoder_mask = nn.make_attention_mask(
        targets > 0, inputs > 0, dtype=config.dtype)

  # Add segmentation block-diagonal attention masks if using segmented data.
  if inputs_segmentation is not None:
    decoder_mask = nn.combine_masks(
        decoder_mask,
        nn.make_attention_mask(targets_segmentation,
                               targets_segmentation,
                               jnp.equal,
                               dtype=config.dtype))
    encoder_decoder_mask = nn.combine_masks(
        encoder_decoder_mask,
        nn.make_attention_mask(targets_segmentation,
                               inputs_segmentation,
                               jnp.equal,
                               dtype=config.dtype))

  logits = self.decoder(
      encoded,
      targets,
      targets_positions=targets_positions,
      decoder_mask=decoder_mask,
      encoder_decoder_mask=encoder_decoder_mask)
  return logits.astype(self.config.dtype)

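# Illustrative sketch (toy data, not from the sources above): with packed
# examples, `nn.make_attention_mask(seg, seg, jnp.equal)` yields a
# block-diagonal mask, so positions may only attend within their own segment.
import jax.numpy as jnp
from flax import linen as nn

segments = jnp.array([[1, 1, 2, 2, 2]])            # two packed sequences
block_diag = nn.make_attention_mask(segments, segments, jnp.equal)
print(block_diag[0, 0])                            # 1 where segment ids match
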
def setup(self):
    config = self.config
    self.embed_dim = config.hidden_size
    self.num_heads = config.num_attention_heads
    self.head_dim = self.embed_dim // self.num_heads

    self.c_attn = FlaxConv1D(features=3 * self.embed_dim, dtype=self.dtype)
    self.c_proj = FlaxConv1D(self.embed_dim, dtype=self.dtype)

    self.resid_dropout = nn.Dropout(rate=config.resid_pdrop)

    self.causal_mask = make_causal_mask(
        jnp.ones((1, config.max_position_embeddings), dtype="bool"),
        dtype="bool")

def decode(self, programs, latents, encoded, latents_padding_mask,
           encoded_padding_mask):
  """Applies decoder on programs and encoded specification."""
  cfg = self.config

  assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                              ' but it is: %d' % programs.ndim)
  assert latents.ndim == 3, ('Number of latents dimensions should be 3,'
                             ' but it is: %d' % latents.ndim)
  assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                             ' but it is: %d' % encoded.ndim)

  # Collapse num_io dimension.
  flat_encoded = models.flatten_num_io_dim(encoded)
  flat_encoded_padding_mask = models.flatten_num_io_dim(encoded_padding_mask)

  latents = self.latent_pos_emb(latents)
  # Concatenate the i/o encoding and latents together.
  flat_encoded = jnp.concatenate([flat_encoded, latents], axis=1)

  # Make attention masks.
  if cfg.decode:
    # For fast decode with caching, programs shape == [batch_size, 1] and
    # cfg.shift = False, cfg.decode = True.
    decoder_mask = None
    latent_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(programs), latents_padding_mask, dtype=cfg.dtype)
    encoder_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(programs), flat_encoded_padding_mask, dtype=cfg.dtype)
    encoder_decoder_mask = jnp.concatenate(
        [encoder_decoder_mask, latent_decoder_mask], axis=-1)
  else:
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(programs > 0, programs > 0, dtype=cfg.dtype),
        nn.make_causal_mask(programs, dtype=cfg.dtype))
    latent_decoder_mask = nn.make_attention_mask(
        programs > 0, latents_padding_mask, dtype=cfg.dtype)
    encoder_decoder_mask = nn.make_attention_mask(
        programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)
    encoder_decoder_mask = jnp.concatenate(
        [encoder_decoder_mask, latent_decoder_mask], axis=-1)

  return self.decoder(programs, flat_encoded, decoder_mask,
                      encoder_decoder_mask)

def get_attention_masks(self, inputs, targets):
  cfg = self.config
  if cfg.decode:
    decoder_mask = None
    encoder_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(targets) > 0, inputs > 0)
  else:
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(targets > 0, targets > 0, dtype=cfg.dtype),
        nn.make_causal_mask(targets, dtype=cfg.dtype))
    encoder_decoder_mask = nn.make_attention_mask(
        targets > 0, inputs > 0, dtype=cfg.dtype)
  return decoder_mask, encoder_decoder_mask

def decode(self,
           encoded,
           inputs,
           targets,
           targets_positions=None,
           inputs_segmentation=None,
           targets_segmentation=None,
           train=False):
  # Make padding attention masks.
  dtype = jnp.bfloat16 if self.use_bfloat16 else jnp.float32
  if self.should_decode:
    # For fast autoregressive decoding, only a special encoder-decoder mask is
    # used.
    decoder_mask = None
    encoder_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(targets) > 0, inputs > 0, dtype=dtype)
  else:
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(targets > 0, targets > 0, dtype=dtype),
        nn.make_causal_mask(targets, dtype=dtype))
    encoder_decoder_mask = nn.make_attention_mask(
        targets > 0, inputs > 0, dtype=dtype)

  # Add segmentation block-diagonal attention masks if using segmented data.
  if inputs_segmentation is not None:
    decoder_mask = nn.combine_masks(
        decoder_mask,
        nn.make_attention_mask(targets_segmentation,
                               targets_segmentation,
                               jnp.equal,
                               dtype=dtype))
    encoder_decoder_mask = nn.combine_masks(
        encoder_decoder_mask,
        nn.make_attention_mask(targets_segmentation,
                               inputs_segmentation,
                               jnp.equal,
                               dtype=dtype))

  logits = self.decoder(
      encoded,
      targets,
      targets_positions=targets_positions,
      decoder_mask=decoder_mask,
      encoder_decoder_mask=encoder_decoder_mask,
      train=train)
  return logits

def __call__(self, inputs, inputs_positions=None, inputs_segmentation=None):
  """Applies TransformerLM on the inputs.

  Args:
    inputs: target data.
    inputs_positions: input subsequence positions for packed examples.
    inputs_segmentation: input segmentation info for packed examples.

  Returns:
    logits array from transformer decoder.
  """
  config = self.config

  # Make padding attention masks.
  if config.decode:
    # For fast autoregressive decoding we use no decoder mask.
    decoder_mask = None
  else:
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(inputs > 0, inputs > 0, dtype=config.dtype),
        nn.make_causal_mask(inputs, dtype=config.dtype))

  # Add segmentation block-diagonal attention masks if using segmented data.
  if inputs_segmentation is not None:
    decoder_mask = nn.combine_masks(
        decoder_mask,
        nn.make_attention_mask(inputs_segmentation,
                               inputs_segmentation,
                               jnp.equal,
                               dtype=config.dtype))

  logits = Decoder(config=config, shared_embedding=None, name='decoder')(
      inputs,
      inputs_positions=inputs_positions,
      inputs_segmentation=inputs_segmentation,
      decoder_mask=decoder_mask,
      encoder_decoder_mask=None)
  return logits.astype(self.config.dtype)

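# Illustrative sketch (assumed helper, mirroring the usual Flax example code):
# the `shift_right` / `shift_inputs` helpers called in these models are taken
# to prepend a BOS (or zero) token and drop the last position, so position i
# of the decoder input only contains tokens before position i of the target.
import jax.numpy as jnp

def shift_right_sketch(x, bos_token=0):
  # Pad one position on the left of the sequence axis, then crop the end.
  padded = jnp.pad(x, [(0, 0), (1, 0)], constant_values=bos_token)
  return padded[:, :-1]

print(shift_right_sketch(jnp.array([[7, 8, 9]])))  # [[0 7 8]]
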
def __call__(self, inputs, temb, train, context, permutations):
  """Applies Transformer model on the inputs.

  Args:
    inputs: Input data.
    temb: The time embedding.
    train: Is the model training?
    context: A context to condition on.
    permutations: A batch of permutations that specifies generation order.

  Returns:
    Output of a transformer decoder.
  """
  cfg = self.config
  assert inputs.ndim == 2  # (batch, len)
  deterministic = not train

  # Permutations give the permutation order, for XLNet-style training only. It
  # is important that permutations are applied _before shifting_. For this
  # reason, we also have to deal with the positional embeddings separately at
  # a later point.
  if permutations is not None:
    assert cfg.is_causal
    assert permutations.shape == inputs.shape
    # Use the permutations to act on the inputs.
    inputs = util_fns.batch_permute(inputs, permutations)

  # Target Embedding
  embedding_layer = nn.Embed(
      num_embeddings=cfg.output_vocab_size,
      features=cfg.emb_dim,
      embedding_init=nn.initializers.normal(stddev=1.0))

  # Concatenate context if available.
  if context is not None:
    assert cfg.context_length == context.shape[1], (
        f'{cfg.context_length} != {context.shape[1]} for {context.shape}')
    inputs = jnp.concatenate([context, inputs], axis=1)

  y = inputs.astype('int32')

  if cfg.is_causal:
    logging.info('Using causal Transformer')
    decoder_mask = nn.make_causal_mask(inputs, dtype=cfg.dtype)
  else:
    logging.info('Using fully connected (non-causal) Transformer')
    decoder_mask = None

  if cfg.is_causal:
    y = shift_inputs(y)

  y = embedding_layer(y)
  y = AddPositionEmbs(config=cfg, name='add_posemb')(y, permutations)
  y = nn.Dropout(rate=cfg.dropout_rate)(y, deterministic=deterministic)

  y = y.astype(cfg.dtype)

  # Target-Input Decoder
  for lyr in range(cfg.num_layers):
    y = EncoderDecoder1DBlock(
        config=cfg, name=f'encoderdecoderblock_{lyr}')(
            y, temb, deterministic, decoder_mask=decoder_mask)
  y = nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_norm')(y)

  logits = nn.Dense(
      cfg.output_vocab_size,
      dtype=cfg.dtype,
      kernel_init=cfg.kernel_init,
      bias_init=cfg.bias_init,
      name='logitdense')(y)

  if context is not None:
    # Take only predictions for inputs, not context.
    logits = logits[:, cfg.context_length:]

  if permutations is not None:
    assert cfg.is_causal
    # Apply the inverse permutation to the logits.
    inv_permutations = util_fns.compute_batch_inverse_permute(permutations)
    logits = util_fns.batch_permute(logits, inv_permutations)

  return logits

def decode(self, programs, encoded, encoded_padding_mask):
  """Applies decoder on programs and encoded specification."""
  cfg = self.config.base_config

  assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                              ' but it is: %d' % programs.ndim)
  assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                             ' but it is: %d' % encoded.ndim)

  # Collapse num_io dimension.
  flat_encoded = base_models.flatten_num_io_dim(encoded)
  flat_encoded_padding_mask = base_models.flatten_num_io_dim(
      encoded_padding_mask)

  if cfg.shift:
    programs = base_models.shift_right(programs, cfg.bos_token)

  # Make attention masks.
  decoder_mask = None
  decoder_relative_position = None  # Relative positions.
  if cfg.decode:
    # For fast decode with caching, programs shape == [batch_size, 1] and
    # cfg.shift = False, cfg.decode = True.
    # TODO(jxihong): Fast decoding currently does not work with new attention.
    encoder_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(programs), flat_encoded_padding_mask, dtype=cfg.dtype)
  else:
    attention_mask_type = self.config.attention_mask_type
    if attention_mask_type == 'baseline':
      decoder_mask = nn.combine_masks(
          nn.make_attention_mask(programs > 0, programs > 0, dtype=cfg.dtype),
          nn.make_causal_mask(programs, dtype=cfg.dtype))
    else:
      if attention_mask_type == 'bos_to_bos':
        # BOS tokens attend to all previous BOS tokens.
        decoder_bos_mask = nn.combine_masks(
            nn.make_attention_mask(programs == cfg.bos_token,
                                   programs == cfg.bos_token,
                                   dtype=cfg.dtype),
            nn.make_causal_mask(programs, dtype=cfg.dtype))
      elif attention_mask_type == 'bos_to_last':
        # BOS tokens attend to all last partial program tokens.
        bos_mask = nn.combine_masks(
            nn.make_attention_mask(programs == cfg.bos_token,
                                   programs == cfg.bos_token,
                                   dtype=cfg.dtype),
            nn.make_causal_mask(programs, dtype=cfg.dtype))
        # Shift bos mask to left to get all previous last partial program
        # tokens.
        decoder_bos_mask = shift_left(bos_mask)
      elif attention_mask_type == 'bos_to_bos_and_last':
        # BOS tokens attend to all previous BOS + last partial program tokens.
        bos_mask = nn.combine_masks(
            nn.make_attention_mask(programs == cfg.bos_token,
                                   programs == cfg.bos_token,
                                   dtype=cfg.dtype),
            nn.make_causal_mask(programs, dtype=cfg.dtype))
        # Shift bos mask to left to get all previous last partial program
        # tokens.
        decoder_bos_mask = jnp.logical_or(bos_mask, shift_left(bos_mask))
      elif attention_mask_type == 'bos_full_attention':
        # BOS tokens attend to all previous tokens, including program tokens.
        decoder_bos_mask = nn.combine_masks(
            nn.make_attention_mask(programs == cfg.bos_token,
                                   programs > 0,
                                   dtype=cfg.dtype),
            nn.make_causal_mask(programs, dtype=cfg.dtype))
      else:
        raise ValueError(
            'Unhandled attention_mask_type: {}'.format(attention_mask_type))
      # Program tokens attend to all previous tokens in partial program.
      decoder_partial_mask = nn.combine_masks(
          make_partial_program_mask(programs,
                                    bos_token=cfg.bos_token,
                                    dtype=cfg.dtype),
          nn.make_causal_mask(programs, dtype=cfg.dtype))
      decoder_mask = nn.combine_masks(
          nn.make_attention_mask(programs > 0, programs > 0, dtype=cfg.dtype),
          jnp.logical_or(decoder_bos_mask, decoder_partial_mask))

    if self.config.bos_special_attention:
      # Make custom relative positions where BOS are separately indexed.
      decoder_relative_position = make_relative_position(programs)
      decoder_partial_relative_position = (
          make_partial_program_relative_position(programs,
                                                 bos_token=cfg.bos_token))
      decoder_relative_position = jnp.where(
          (programs == cfg.bos_token)[Ellipsis, None],
          decoder_partial_relative_position,
          decoder_relative_position)
    else:
      decoder_relative_position = None

    encoder_decoder_mask = nn.make_attention_mask(
        programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)

  return self.decoder(programs, flat_encoded, decoder_mask,
                      encoder_decoder_mask, decoder_relative_position)

def __call__(self,
             inputs,
             train,
             inputs_positions=None,
             inputs_segmentation=None):
  """Applies Transformer model on the inputs.

  Args:
    inputs: input data.
    train: bool: if model is training.
    inputs_positions: input subsequence positions for packed examples.
    inputs_segmentation: input segmentation info for packed examples.

  Returns:
    output of a transformer decoder.
  """
  assert inputs.ndim == 2  # (batch, len)
  dtype = utils.dtype_from_str(self.model_dtype)

  if self.decode:
    # For fast autoregressive decoding we use no decoder mask.
    decoder_mask = None
  else:
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(inputs > 0, inputs > 0, dtype=dtype),
        nn.make_causal_mask(inputs, dtype=dtype))
  if inputs_segmentation is not None:
    decoder_mask = nn.combine_masks(
        decoder_mask,
        nn.make_attention_mask(inputs_segmentation,
                               inputs_segmentation,
                               jnp.equal,
                               dtype=dtype))

  y = inputs.astype('int32')
  if not self.decode:
    y = shift_inputs(y, segment_ids=inputs_segmentation)

  # TODO(gdahl,znado): this code appears to be accessing out-of-bounds
  # indices for dataset_lib:proteins_test. This will break when jnp.take()
  # is updated to return NaNs for out-of-bounds indices.
  # Debug why this is the case.
  y = jnp.clip(y, 0, self.vocab_size - 1)

  if self.shared_embedding is None:
    output_embed = nn.Embed(
        num_embeddings=self.vocab_size,
        features=self.emb_dim,
        embedding_init=nn.initializers.normal(stddev=1.0))
  else:
    output_embed = self.shared_embedding

  y = output_embed(y)
  y = AddPositionEmbs(
      max_len=self.max_len,
      posemb_init=sinusoidal_init(max_len=self.max_len),
      decode=self.decode,
      name='posembed_output')(y, inputs_positions=inputs_positions)
  y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=not train)

  y = y.astype(dtype)

  for _ in range(self.num_layers):
    y = Transformer1DBlock(
        qkv_dim=self.qkv_dim,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        dropout_rate=self.dropout_rate,
        attention_dropout_rate=self.attention_dropout_rate,
        attention_fn=self.attention_fn,
        normalizer=self.normalizer,
        dtype=dtype)(
            inputs=y,
            train=train,
            decoder_mask=decoder_mask,
            encoder_decoder_mask=None,
            inputs_positions=None,
            inputs_segmentation=None)
  if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']:
    maybe_normalize = model_utils.get_normalizer(self.normalizer, train,
                                                 dtype=dtype)
    y = maybe_normalize()(y)

  if self.logits_via_embedding:
    # Use the transpose of embedding matrix for logit transform.
    logits = output_embed.attend(y.astype(jnp.float32))
    # Correctly normalize pre-softmax logits for this shared case.
    logits = logits / jnp.sqrt(y.shape[-1])
  else:
    logits = nn.Dense(
        self.vocab_size,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6),
        dtype=dtype,
        name='logits_dense')(y)
  return logits.astype(dtype)
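
# Illustrative sketch (toy sizes, not from the sources above): with
# `logits_via_embedding`, the decoder reuses the embedding matrix for the
# output projection. `nn.Embed.attend(y)` computes y @ embedding.T, and
# dividing by sqrt(d_model) keeps the shared-weight logits on a scale
# comparable to a separately initialized output layer.
import jax
import jax.numpy as jnp
from flax import linen as nn

embed = nn.Embed(num_embeddings=11, features=4)
variables = embed.init(jax.random.PRNGKey(0), jnp.zeros((1, 3), jnp.int32))
y = jnp.ones((1, 3, 4))                            # stand-in decoder output
logits = embed.apply(variables, y, method=embed.attend) / jnp.sqrt(4.0)
print(logits.shape)                                # (1, 3, 11)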