def setup(self):
    self.dropout_layer = nn.Dropout(rate=self.config.dropout)

    embed_dim = self.config.d_model
    self.padding_idx = self.config.pad_token_id
    self.max_target_positions = self.config.max_position_embeddings
    self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0

    self.embed_tokens = nn.Embed(
        self.config.vocab_size,
        embed_dim,
        embedding_init=jax.nn.initializers.normal(self.config.init_std),
    )

    # XGLM is set up so that if padding_idx is specified then offset the embedding ids by 2
    # and adjust num_embeddings appropriately. Other models don't have this hack.
    self.offset = 2
    self.embed_positions = create_sinusoidal_positions(
        self.config.max_position_embeddings + self.offset, embed_dim
    )
    self.layers = FlaxXGLMDecoderLayerCollection(self.config, self.dtype)
    self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
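For reference, `create_sinusoidal_positions` above returns a precomputed position table that is indexed rather than learned. A minimal sketch of such a helper, assuming the common fairseq-style layout (first half sine, second half cosine) and an even `dim`; the actual helper in the source may differ:

# Sketch only: names and layout are assumptions, not the real implementation.
import numpy as np
import jax.numpy as jnp

def create_sinusoidal_positions(n_positions, dim):
  # One frequency per pair of channels, geometrically spaced as in the
  # original Transformer paper.
  inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
  angles = np.einsum('i,j->ij', np.arange(n_positions), inv_freq)
  table = np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)
  return jnp.asarray(table, dtype=jnp.float32)  # [n_positions, dim]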
def __call__(self, inputs, inputs_positions=None, encoder_mask=None):
  """Applies Transformer model on the inputs.

  Args:
    inputs: input data.
    inputs_positions: input subsequence positions for packed examples.
    encoder_mask: encoder self-attention mask.

  Returns:
    output of a transformer encoder.
  """
  cfg = self.config
  assert inputs.ndim == 2  # (batch, len)

  # Input Embedding
  if self.shared_embedding is None:
    input_embed = nn.Embed(
        num_embeddings=cfg.vocab_size,
        features=cfg.emb_dim,
        embedding_init=nn.initializers.normal(stddev=1.0))
  else:
    input_embed = self.shared_embedding
  x = inputs.astype('int32')
  x = input_embed(x)
  x = AddPositionEmbs(
      config=cfg, decode=False, name='posembed_input')(
          x, inputs_positions=inputs_positions)
  x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)

  x = x.astype(cfg.dtype)

  # Input Encoder
  for lyr in range(cfg.num_layers):
    x = Encoder1DBlock(config=cfg, name=f'encoderblock_{lyr}')(x, encoder_mask)

  encoded = nn.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x)

  return encoded
def setup(self):
    self.dropout_layer = nn.Dropout(rate=self.config.dropout)

    embed_dim = self.config.hidden_size
    self.padding_idx = self.config.pad_token_id
    self.max_target_positions = self.config.max_position_embeddings

    self.embed_tokens = nn.Embed(
        self.config.vocab_size,
        self.config.word_embed_proj_dim,
        embedding_init=jax.nn.initializers.normal(self.config.init_std),
    )

    self.embed_positions = FlaxOPTLearnedPositionalEmbedding(
        self.config.max_position_embeddings,
        embed_dim,
        embedding_init=jax.nn.initializers.normal(self.config.init_std),
    )

    if self.config.word_embed_proj_dim != self.config.hidden_size:
        self.project_in = nn.Dense(self.config.hidden_size, use_bias=False)
        self.project_out = nn.Dense(self.config.word_embed_proj_dim, use_bias=False)
    else:
        self.project_in = None
        self.project_out = None

    # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward
    # compatibility with checkpoints that have been fine-tuned before transformers v4.20.1.
    # See https://github.com/facebookresearch/metaseq/pull/164
    if self.config.do_layer_norm_before and not self.config._remove_final_layer_norm:
        self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
    else:
        self.final_layer_norm = None

    self.layers = FlaxOPTDecoderLayerCollection(self.config, self.dtype)
def __call__(self, x):
  out = {}

  embedding = nn.Embed(num_embeddings=self.vocab_size, features=self.width)
  x = out['embedded'] = embedding(x)

  # Add posemb
  n, l, d = x.shape  # pylint: disable=unused-variable
  x = x + self.param('pos_embedding',
                     nn.initializers.normal(stddev=1 / jnp.sqrt(d)),
                     (1, l, d), x.dtype)

  x = models_vit.Encoder(
      num_layers=self.num_layers,
      mlp_dim=self.mlp_dim,
      num_heads=self.num_heads,
      dropout_rate=self.dropout_rate,
      attention_dropout_rate=0,
      add_position_embedding=False)(x, train=False)

  x = out['pre_logits'] = x[:, -1, :]  # note that we take *last* token
  x = out['logits'] = nn.Dense(self.num_classes, name='head')(x)
  return x, out
def __call__(self, inputs, outputs):
  """Applies Transformer model to encode the IO specification.

  Args:
    inputs: input data `[batch_size, num_io, length]`.
    outputs: output data `[batch_size, num_io, length2]`.

  Returns:
    Encoded IO data `[batch_size, num_io, length2, dim]`.
  """
  cfg = self.config

  # Inputs and outputs shared embeddings.
  embed = nn.Embed(
      num_embeddings=cfg.vocab_size,
      features=cfg.emb_dim,
      embedding_init=nn.initializers.normal(stddev=1.0),
      name='embed')
  if not cfg.use_relative_attention:
    pos_emb = AddPositionEmbs(config=cfg, cache=False, name='posembed_io')

  x = inputs.astype('int32')
  y = outputs.astype('int32')

  # Make attention masks.
  inputs_encoder_mask = nn.make_attention_mask(x > 0, x > 0, dtype=cfg.dtype)
  outputs_encoder_mask = nn.make_attention_mask(y > 0, y > 0, dtype=cfg.dtype)
  encoder_decoder_mask = nn.make_attention_mask(y > 0, x > 0, dtype=cfg.dtype)

  # Embed inputs.
  x = embed(x)
  if not cfg.use_relative_attention:
    x = pos_emb(x)
  x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)
  x = x.astype(cfg.dtype)

  for lyr in range(cfg.num_layers):
    x = EncoderBlock(  # Attend to inputs.
        config=cfg,
        bidirectional_attention=True,
        num_relative_position_buckets=cfg.num_input_relative_position_buckets,
        max_distance=cfg.max_input_distance,
        name=f'encoderblock_{lyr}')(x, inputs_encoder_mask)
  x = nn.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x)

  # Embed outputs.
  y = embed(y)
  if not cfg.use_relative_attention:
    y = pos_emb(y)
  y = nn.Dropout(rate=cfg.dropout_rate)(y, deterministic=cfg.deterministic)

  encode_decoder_cfg = cfg.replace(decode=False)
  for lyr in range(cfg.num_layers):
    y = EncoderDecoderBlock(  # Double attend to inputs and outputs.
        config=encode_decoder_cfg,
        bidirectional_attention=True,
        num_relative_position_buckets=cfg.num_output_relative_position_buckets,
        max_distance=cfg.max_output_distance,
        relative_cross_attention=cfg.use_relative_attention,
        bidirectional_cross_attention=True,
        num_relative_position_buckets_cross_attention=(
            cfg.num_input_cross_output_relative_position_buckets),
        max_distance_cross_attention=cfg.max_input_cross_output_distance,
        name=f'encoderdecoderblock_{lyr}')(
            y, x, outputs_encoder_mask, encoder_decoder_mask)
  y = nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_norm')(y)

  return y
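A quick shape check for the padding-mask construction used above. `nn.make_attention_mask` broadcasts a query mask against a key mask and inserts a head dimension; the toy shapes below are invented for illustration:

import jax.numpy as jnp
import flax.linen as nn

x = jnp.array([[[5, 7, 0, 0]]])  # [batch=1, num_io=1, length=4]; 0 is padding
y = jnp.array([[[3, 0]]])        # [batch=1, num_io=1, length2=2]

self_mask = nn.make_attention_mask(x > 0, x > 0)   # shape (1, 1, 1, 4, 4)
cross_mask = nn.make_attention_mask(y > 0, x > 0)  # shape (1, 1, 1, 2, 4)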
def __call__(self,
             targets,
             encoded,
             decoder_mask=None,
             encoder_decoder_mask=None,
             decoder_relative_position=None,
             encoder_decoder_relative_position=None):
  """Applies Transformer to decode the targets.

  Args:
    targets: target outputs.
    encoded: encoded input data from encoder `[batch, ..., length, mlp_dim]`.
    decoder_mask: decoder self-attention mask.
    encoder_decoder_mask: encoder-decoder attention mask.
    decoder_relative_position: decoder relative positions tensor
      `[batch_sizes..., length2, length2]`.
    encoder_decoder_relative_position: encoder-decoder relative positions
      tensor `[batch_sizes..., length2, length]`.

  Returns:
    output of a transformer decoder.
  """
  cfg = self.config

  assert encoded.ndim == targets.ndim + 1

  output_embed = nn.Embed(
      num_embeddings=cfg.output_vocab_size,
      features=cfg.emb_dim,
      embedding_init=nn.initializers.normal(stddev=1.0),
      name='embed_output')

  heads = dict()
  y = targets.astype('int32')
  if cfg.shift:
    y = shift_right(y, cfg.bos_token)
  y = output_embed(y)
  if not cfg.use_relative_attention:
    y = AddPositionEmbs(
        config=cfg, cache=cfg.decode, name='posembed_output')(y)
  y = nn.Dropout(rate=cfg.dropout_rate)(y, deterministic=cfg.deterministic)

  y = y.astype(cfg.dtype)

  # Target-Input Decoder
  for lyr in range(cfg.num_layers):
    y = EncoderDecoderBlock(
        config=cfg,
        bidirectional_attention=False,
        num_relative_position_buckets=(
            cfg.num_program_relative_position_buckets),
        max_distance=cfg.max_program_distance,
        # relative_cross_attention=cfg.use_relative_attention,
        relative_cross_attention=False,
        bidirectional_cross_attention=True,
        num_relative_position_buckets_cross_attention=(
            cfg.num_program_cross_embed_relative_position_buckets),
        max_distance_cross_attention=cfg.max_program_cross_embed_distance,
        name=f'encoderdecoderblock_{lyr}')(
            y, encoded, decoder_mask, encoder_decoder_mask,
            decoder_relative_position, encoder_decoder_relative_position)
  y = nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_norm')(y)

  heads['output_emb'] = y * (
      jnp.where(targets > 0, 1, 0).astype(jnp.float32)[Ellipsis, None])
  logits = nn.Dense(
      cfg.output_vocab_size,
      kernel_init=cfg.kernel_init,
      bias_init=cfg.bias_init,
      name='logitdense')(y)
  heads['logits'] = logits
  if cfg.output_head:
    return heads[cfg.output_head]
  else:
    return heads  # Return both output embeddings and logits.
def __call__(self, inputs, temb, train, context, permutations):
  """Applies Transformer model on the inputs.

  Args:
    inputs: Input data.
    temb: The time embedding.
    train: Is the model training?
    context: A context to condition on.
    permutations: A batch of permutations that specifies generation order.

  Returns:
    Output of a transformer decoder.
  """
  cfg = self.config
  assert inputs.ndim == 2  # (batch, len)
  deterministic = not train

  # Permutations give the permutation order, for XLNet style training only.
  # It is important that permutations are applied _before shifting_. For this
  # reason, we also have to deal with the positional embeddings separately at
  # a later point.
  if permutations is not None:
    assert cfg.is_causal
    assert permutations.shape == inputs.shape
    # Use the permutations to act on the inputs.
    inputs = util_fns.batch_permute(inputs, permutations)

  # Target Embedding
  embedding_layer = nn.Embed(
      num_embeddings=cfg.output_vocab_size,
      features=cfg.emb_dim,
      embedding_init=nn.initializers.normal(stddev=1.0))

  # Concatenate context if available.
  if context is not None:
    assert cfg.context_length == context.shape[1], (
        f'{cfg.context_length} != {context.shape[1]} for {context.shape}')
    inputs = jnp.concatenate([context, inputs], axis=1)

  y = inputs.astype('int32')

  if cfg.is_causal:
    logging.info('Using causal Transformer')
    decoder_mask = nn.make_causal_mask(inputs, dtype=cfg.dtype)
  else:
    logging.info('Using fully connected (non-causal) Transformer')
    decoder_mask = None

  if cfg.is_causal:
    y = shift_inputs(y)
  y = embedding_layer(y)
  y = AddPositionEmbs(config=cfg, name='add_posemb')(y, permutations)
  y = nn.Dropout(rate=cfg.dropout_rate)(y, deterministic=deterministic)

  y = y.astype(cfg.dtype)

  # Target-Input Decoder
  for lyr in range(cfg.num_layers):
    y = EncoderDecoder1DBlock(config=cfg, name=f'encoderdecoderblock_{lyr}')(
        y, temb, deterministic, decoder_mask=decoder_mask)
  y = nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_norm')(y)

  logits = nn.Dense(
      cfg.output_vocab_size,
      dtype=cfg.dtype,
      kernel_init=cfg.kernel_init,
      bias_init=cfg.bias_init,
      name='logitdense')(y)

  if context is not None:
    # Take only predictions for inputs, not context.
    logits = logits[:, cfg.context_length:]

  if permutations is not None:
    assert cfg.is_causal
    # Apply the inverse permutation to the logits.
    inv_permutations = util_fns.compute_batch_inverse_permute(permutations)
    logits = util_fns.batch_permute(logits, inv_permutations)

  return logits
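A toy illustration of the permutation plumbing above. `util_fns.batch_permute` and `compute_batch_inverse_permute` live elsewhere in that codebase; this stand-in, which simply gathers along the sequence axis, is an assumption about their behavior:

import jax.numpy as jnp

def batch_permute(x, permutations):
  # Apply a per-example permutation over the sequence dimension.
  return jnp.take_along_axis(x, permutations, axis=1)

tokens = jnp.array([[10, 11, 12]])
perm = jnp.array([[2, 0, 1]])
print(batch_permute(tokens, perm))  # [[12 10 11]]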
def __call__(self, targets, encoded, decoder_mask=None, encoder_decoder_mask=None):
  """Applies Transformer to decode the targets.

  Args:
    targets: target outputs.
    encoded: encoded input data from encoder `[batch, ..., length, mlp_dim]`.
    decoder_mask: decoder self-attention mask.
    encoder_decoder_mask: encoder-decoder attention mask.

  Returns:
    output of a transformer decoder.
  """
  cfg = self.config

  assert encoded.ndim == targets.ndim + 1

  output_embed = nn.Embed(
      num_embeddings=cfg.output_vocab_size,
      features=cfg.emb_dim,
      embedding_init=nn.initializers.normal(stddev=1.0),
      name='embed_output')

  if cfg.use_relative_attention:
    attention_fn = functools.partial(
        relative_attention.RelativeMultiHeadDotProductAttention,
        num_relative_position_buckets=cfg.num_relative_position_buckets,
        causal=False)
    self_attention_fn = functools.partial(
        relative_attention.RelativeSelfAttention,
        num_relative_position_buckets=cfg.num_relative_position_buckets,
        causal=True)
  else:
    attention_fn = nn.MultiHeadDotProductAttention
    self_attention_fn = nn.SelfAttention

  heads = dict()
  y = targets.astype('int32')
  if cfg.shift:
    y = shift_right(y, cfg.bos_token)
  y = output_embed(y)
  if not cfg.use_relative_attention:
    y = AddPositionEmbs(
        config=cfg, cache=cfg.decode, name='posembed_output')(y)
  y = nn.Dropout(rate=cfg.dropout_rate)(y, deterministic=cfg.deterministic)

  y = y.astype(cfg.dtype)

  # Target-Input Decoder
  for lyr in range(cfg.num_layers):
    y = EncoderDecoderBlock(
        config=cfg,
        dot_product_attention_fn=attention_fn,
        self_attention_fn=self_attention_fn,
        name=f'encoderdecoderblock_{lyr}')(
            y, encoded, decoder_mask, encoder_decoder_mask)
  y = nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_norm')(y)

  heads['output_emb'] = y * (
      jnp.where(targets > 0, 1, 0).astype(jnp.float32)[Ellipsis, None])
  logits = nn.Dense(
      cfg.output_vocab_size,
      kernel_init=cfg.kernel_init,
      bias_init=cfg.bias_init,
      name='logitdense')(y)
  heads['logits'] = logits
  if cfg.output_head:
    return heads[cfg.output_head]
  else:
    return heads  # Return both output embeddings and logits.
def __call__(self,
             encoded,
             targets,
             targets_positions=None,
             decoder_mask=None,
             encoder_decoder_mask=None):
  """Applies Transformer model on the inputs.

  Args:
    encoded: encoded input data from encoder.
    targets: target inputs.
    targets_positions: input subsequence positions for packed examples.
    decoder_mask: decoder self-attention mask.
    encoder_decoder_mask: encoder-decoder attention mask.

  Returns:
    output of a transformer decoder.
  """
  cfg = self.config

  assert encoded.ndim == 3  # (batch, len, depth)
  assert targets.ndim == 2  # (batch, len)

  # Target Embedding
  if self.shared_embedding is None:
    output_embed = nn.Embed(
        num_embeddings=cfg.output_vocab_size,
        features=cfg.emb_dim,
        embedding_init=nn.initializers.normal(stddev=1.0))
  else:
    output_embed = self.shared_embedding

  y = targets.astype('int32')
  if not cfg.decode:
    y = shift_right(y)
  y = output_embed(y)
  y = AddPositionEmbs(
      config=cfg, decode=cfg.decode, name='posembed_output')(
          y, inputs_positions=targets_positions)
  y = nn.Dropout(rate=cfg.dropout_rate)(y, deterministic=cfg.deterministic)

  y = y.astype(cfg.dtype)

  # Target-Input Decoder
  for lyr in range(cfg.num_layers):
    y = EncoderDecoder1DBlock(
        config=cfg, name=f'encoderdecoderblock_{lyr}')(
            y,
            encoded,
            decoder_mask=decoder_mask,
            encoder_decoder_mask=encoder_decoder_mask)
  y = nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_norm')(y)

  # Decoded Logits
  if cfg.logits_via_embedding:
    # Use the transpose of embedding matrix for logit transform.
    logits = output_embed.attend(y.astype(jnp.float32))
    # Correctly normalize pre-softmax logits for this shared case.
    logits = logits / jnp.sqrt(y.shape[-1])
  else:
    logits = nn.Dense(
        cfg.output_vocab_size,
        dtype=cfg.dtype,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init,
        name='logitdense')(y)
  return logits
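A standalone illustration of the `logits_via_embedding` branch above: `nn.Embed.attend` multiplies activations by the transposed embedding table, so a tied output projection needs no extra parameters. Shapes here are invented for the example:

import jax
import jax.numpy as jnp
import flax.linen as nn

embed = nn.Embed(num_embeddings=11, features=8)
variables = embed.init(jax.random.PRNGKey(0), jnp.zeros((1, 5), jnp.int32))

y = jnp.ones((1, 5, 8))                                  # decoder output
logits = embed.apply(variables, y, method=embed.attend)  # shape (1, 5, 11)
logits = logits / jnp.sqrt(y.shape[-1])                  # same rescaling as above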
def test_embed_hash(self):
  self.assertEqual(hash(nn.Embed(2, 3)), hash(nn.Embed(2, 3)))
  self.assertNotEqual(hash(nn.Embed(3, 4)), hash(nn.Embed(2, 3)))
def __call__(self, inputs, train, inputs_positions=None, inputs_segmentation=None):
  """Applies Transformer model on the inputs.

  Args:
    inputs: input data.
    train: bool: if model is training.
    inputs_positions: input subsequence positions for packed examples.
    inputs_segmentation: input segmentation info for packed examples.

  Returns:
    output of a transformer decoder.
  """
  assert inputs.ndim == 2  # (batch, len)
  dtype = utils.dtype_from_str(self.model_dtype)

  if self.decode:
    # For fast autoregressive decoding we use no decoder mask.
    decoder_mask = None
  else:
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(inputs > 0, inputs > 0, dtype=dtype),
        nn.make_causal_mask(inputs, dtype=dtype))
    if inputs_segmentation is not None:
      decoder_mask = nn.combine_masks(
          decoder_mask,
          nn.make_attention_mask(
              inputs_segmentation, inputs_segmentation, jnp.equal,
              dtype=dtype))

  y = inputs.astype('int32')
  if not self.decode:
    y = shift_inputs(y, segment_ids=inputs_segmentation)

  # TODO(gdahl,znado): this code appears to be accessing out-of-bounds
  # indices for dataset_lib:proteins_test. This will break when jnp.take() is
  # updated to return NaNs for out-of-bounds indices.
  # Debug why this is the case.
  y = jnp.clip(y, 0, self.vocab_size - 1)

  if self.shared_embedding is None:
    output_embed = nn.Embed(
        num_embeddings=self.vocab_size,
        features=self.emb_dim,
        embedding_init=nn.initializers.normal(stddev=1.0))
  else:
    output_embed = self.shared_embedding

  y = output_embed(y)
  y = AddPositionEmbs(
      max_len=self.max_len,
      posemb_init=sinusoidal_init(max_len=self.max_len),
      decode=self.decode,
      name='posembed_output')(
          y, inputs_positions=inputs_positions)
  y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=not train)

  y = y.astype(dtype)

  for _ in range(self.num_layers):
    y = Transformer1DBlock(
        qkv_dim=self.qkv_dim,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        dropout_rate=self.dropout_rate,
        attention_dropout_rate=self.attention_dropout_rate,
        attention_fn=self.attention_fn,
        normalizer=self.normalizer,
        dtype=dtype)(
            inputs=y,
            train=train,
            decoder_mask=decoder_mask,
            encoder_decoder_mask=None,
            inputs_positions=None,
            inputs_segmentation=None)
  if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']:
    maybe_normalize = model_utils.get_normalizer(
        self.normalizer, train, dtype=dtype)
    y = maybe_normalize()(y)

  if self.logits_via_embedding:
    # Use the transpose of embedding matrix for logit transform.
    logits = output_embed.attend(y.astype(jnp.float32))
    # Correctly normalize pre-softmax logits for this shared case.
    logits = logits / jnp.sqrt(y.shape[-1])
  else:
    logits = nn.Dense(
        self.vocab_size,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6),
        dtype=dtype,
        name='logits_dense')(y)
  return logits.astype(dtype)
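A quick sanity check for the decoder-mask construction above: the padding mask and the causal mask are ANDed together by `nn.combine_masks`, so a position attends only to earlier, non-padding tokens. Values below are invented for illustration:

import jax.numpy as jnp
import flax.linen as nn

inputs = jnp.array([[5, 6, 0]])  # 0 is padding
mask = nn.combine_masks(
    nn.make_attention_mask(inputs > 0, inputs > 0),
    nn.make_causal_mask(inputs))
print(mask.shape)  # (1, 1, 3, 3); key position 2 is masked out everywhere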
def setup(self):
  self.embed = nn.Embed(
      num_embeddings=self.num_embeddings,
      features=self.features,
      embedding_init=self.embedding_init,
  )
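Minimal usage of this wrapper pattern (the surrounding module and its field names are assumed for the example): integer ids in, dense vectors out, with the table initialized by `embedding_init`:

from typing import Callable

import jax
import jax.numpy as jnp
import flax.linen as nn

class TokenEmbedder(nn.Module):
  num_embeddings: int = 100
  features: int = 16
  embedding_init: Callable = nn.initializers.normal(stddev=1.0)

  def setup(self):
    self.embed = nn.Embed(
        num_embeddings=self.num_embeddings,
        features=self.features,
        embedding_init=self.embedding_init,
    )

  def __call__(self, ids):
    return self.embed(ids)  # [batch, len] int32 -> [batch, len, features]

ids = jnp.array([[1, 2, 3]], dtype=jnp.int32)
model = TokenEmbedder()
params = model.init(jax.random.PRNGKey(0), ids)
out = model.apply(params, ids)  # shape (1, 3, 16)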
def __call__(self,
             encoded,
             targets,
             targets_positions=None,
             decoder_mask=None,
             encoder_decoder_mask=None,
             train=True):
  """Applies Transformer model on the inputs.

  Args:
    encoded: encoded input data from encoder.
    targets: target inputs.
    targets_positions: input subsequence positions for packed examples.
    decoder_mask: decoder self-attention mask.
    encoder_decoder_mask: encoder-decoder attention mask.
    train: whether it is training.

  Returns:
    output of a transformer decoder.
  """
  assert encoded.ndim == 3  # (batch, len, depth)
  assert targets.ndim == 2  # (batch, len)
  dtype = _get_dtype(self.use_bfloat16)

  # Target Embedding
  if self.shared_embedding is None:
    output_embed = nn.Embed(
        num_embeddings=self.output_vocab_size,
        features=self.emb_dim,
        embedding_init=nn.initializers.normal(stddev=1.0),
        name='output_vocab_embeddings')
  else:
    output_embed = self.shared_embedding

  y = targets.astype('int32')
  if not self.decode:
    y = shift_right(y)
  y = output_embed(y)
  y = AddPositionEmbs(
      max_len=self.max_len, decode=self.decode, name='posembed_output')(
          y, inputs_positions=targets_positions, train=train)
  y = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(y)

  if self.use_bfloat16:
    y = y.astype(jnp.bfloat16)

  # Target-Input Decoder
  for lyr in range(self.dec_num_layers):
    y = EncoderDecoder1DBlock(
        qkv_dim=self.qkv_dim,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        dtype=dtype,
        dropout_rate=self.dropout_rate,
        attention_dropout_rate=self.attention_dropout_rate,
        normalizer=self.normalizer,
        dec_self_attn_kernel_init_fn=self.dec_self_attn_kernel_init_fn,
        dec_cross_attn_kernel_init_fn=self.dec_cross_attn_kernel_init_fn,
        decode=self.decode,
        name=f'encoderdecoderblock_{lyr}')(
            y,
            encoded,
            decoder_mask=decoder_mask,
            encoder_decoder_mask=encoder_decoder_mask,
            train=train)
  if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']:
    maybe_normalize = model_utils.get_normalizer(self.normalizer, train)
    y = maybe_normalize()(y)

  # Decoded Logits
  if self.logits_via_embedding:
    # Use the transpose of embedding matrix for logit transform.
    logits = output_embed.attend(y.astype(jnp.float32))
    # Correctly normalize pre-softmax logits for this shared case.
    logits = logits / jnp.sqrt(y.shape[-1])
  else:
    logits = nn.Dense(
        self.output_vocab_size,
        dtype=dtype,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6),
        name='logitdense')(y)
  return logits
def __call__(self, inputs, inputs_positions=None, encoder_mask=None, train=True):
  """Applies Transformer model on the inputs.

  Args:
    inputs: input data.
    inputs_positions: input subsequence positions for packed examples.
    encoder_mask: encoder self-attention mask.
    train: whether it is training.

  Returns:
    output of a transformer encoder.
  """
  assert inputs.ndim == 2  # (batch, len)

  # Input embedding.
  if self.shared_embedding is None:
    input_embed = nn.Embed(
        num_embeddings=self.vocab_size,
        features=self.emb_dim,
        embedding_init=nn.initializers.normal(stddev=1.0),
        name='input_vocab_embeddings')
  else:
    input_embed = self.shared_embedding

  x = inputs.astype('int32')
  x = input_embed(x)
  x = AddPositionEmbs(
      max_len=self.max_len, decode=False, name='posembed_input')(
          x, inputs_positions=inputs_positions, train=train)
  x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)

  if self.use_bfloat16:
    x = x.astype(jnp.bfloat16)
    dtype = jnp.bfloat16
  else:
    dtype = jnp.float32

  # Input encoder.
  for lyr in range(self.enc_num_layers):
    x = Encoder1DBlock(
        qkv_dim=self.qkv_dim,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        dtype=dtype,
        dropout_rate=self.dropout_rate,
        attention_dropout_rate=self.attention_dropout_rate,
        normalizer=self.normalizer,
        enc_self_attn_kernel_init_fn=self.enc_self_attn_kernel_init_fn,
        name=f'encoderblock_{lyr}')(
            x, encoder_mask=encoder_mask, train=train)
  if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']:
    maybe_normalize = model_utils.get_normalizer(self.normalizer, train)
    x = maybe_normalize()(x)
  return x