def decode(self,
           encoded,
           inputs,  # only needed for masks
           targets,
           targets_positions=None,
           inputs_segmentation=None,
           targets_segmentation=None):
  """Applies Transformer decoder-branch on encoded-input and target.

  Args:
    encoded: encoded input data from encoder.
    inputs: input data (only needed for masking).
    targets: target data.
    targets_positions: target subsequence positions for packed examples.
    inputs_segmentation: input segmentation info for packed examples.
    targets_segmentation: target segmentation info for packed examples.

  Returns:
    logits array from transformer decoder.
  """
  config = self.config

  # Make padding attention masks.
  if config.decode:
    # For fast autoregressive decoding, only a special encoder-decoder mask is
    # used.
    decoder_mask = None
    encoder_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(targets) > 0, inputs > 0, dtype=config.dtype)
  else:
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(targets > 0, targets > 0, dtype=config.dtype),
        nn.make_causal_mask(targets, dtype=config.dtype))
    encoder_decoder_mask = nn.make_attention_mask(
        targets > 0, inputs > 0, dtype=config.dtype)

  # Add segmentation block-diagonal attention masks if using segmented data.
  if inputs_segmentation is not None:
    decoder_mask = nn.combine_masks(
        decoder_mask,
        nn.make_attention_mask(
            targets_segmentation, targets_segmentation, jnp.equal,
            dtype=config.dtype))
    encoder_decoder_mask = nn.combine_masks(
        encoder_decoder_mask,
        nn.make_attention_mask(
            targets_segmentation, inputs_segmentation, jnp.equal,
            dtype=config.dtype))

  logits = self.decoder(
      encoded,
      targets,
      targets_positions=targets_positions,
      decoder_mask=decoder_mask,
      encoder_decoder_mask=encoder_decoder_mask)
  return logits.astype(self.config.dtype)

def decode(self, programs, encoded, encoded_padding_mask):
  """Applies decoder on programs and encoded specification."""
  cfg = self.config

  assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                              ' but it is: %d' % programs.ndim)
  assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                             ' but it is: %d' % encoded.ndim)

  # Collapse num_io dimension.
  flat_encoded = base_models.flatten_num_io_dim(encoded)
  flat_encoded_padding_mask = base_models.flatten_num_io_dim(
      encoded_padding_mask)

  preshift_programs = programs  # Save pre-shifted programs for padding mask.
  if cfg.shift:
    programs = base_models.shift_right(programs, cfg.bos_token)

  # Make attention masks.
  if cfg.decode:
    # For fast decode with caching, programs shape == [batch_size, 1] and
    # cfg.shift = False, cfg.decode = True.
    # TODO(jxihong): Fast decoding currently does not work with new attention.
    decoder_mask = None
    encoder_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(programs), flat_encoded_padding_mask, dtype=cfg.dtype)
  else:
    # BOS tokens attend to all previous BOS tokens.
    decoder_bos_mask = nn.combine_masks(
        nn.make_attention_mask(
            programs == cfg.bos_token, programs == cfg.bos_token,
            dtype=cfg.dtype),
        nn.make_causal_mask(programs, dtype=cfg.dtype))
    # Program tokens attend to all previous tokens in partial program.
    decoder_partial_mask = nn.combine_masks(
        make_partial_program_mask(
            programs, bos_token=cfg.bos_token, dtype=cfg.dtype),
        nn.make_causal_mask(programs, dtype=cfg.dtype))
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(
            preshift_programs > 0, preshift_programs > 0, dtype=cfg.dtype),
        jnp.logical_or(decoder_bos_mask, decoder_partial_mask))
    encoder_decoder_mask = nn.make_attention_mask(
        programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)

  return self.decoder(
      programs, flat_encoded, decoder_mask, encoder_decoder_mask)

def decode(self, programs, encoded, encoded_padding_mask):
  """Applies decoder on programs and encoded specification."""
  cfg = self.config

  assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                              ' but it is: %d' % programs.ndim)
  assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                             ' but it is: %d' % encoded.ndim)

  # Collapse num_io dimension.
  flat_encoded = flatten_num_io_dim(encoded)
  flat_encoded_padding_mask = flatten_num_io_dim(encoded_padding_mask)

  # Make attention masks.
  if cfg.decode:
    # For fast decode with caching, programs shape == [batch_size, 1] and
    # cfg.shift = False, cfg.decode = True.
    decoder_mask = None
    encoder_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(programs), flat_encoded_padding_mask, dtype=cfg.dtype)
  else:
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(programs > 0, programs > 0, dtype=cfg.dtype),
        nn.make_causal_mask(programs, dtype=cfg.dtype))
    encoder_decoder_mask = nn.make_attention_mask(
        programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)

  return self.decoder(
      programs, flat_encoded, decoder_mask, encoder_decoder_mask)

def encode(self, inputs, inputs_positions=None, inputs_segmentation=None):
  """Applies Transformer encoder-branch on the inputs.

  Args:
    inputs: input data.
    inputs_positions: input subsequence positions for packed examples.
    inputs_segmentation: input segmentation info for packed examples.

  Returns:
    encoded feature array from the transformer encoder.
  """
  cfg = self.config
  # Make padding attention mask.
  encoder_mask = nn.make_attention_mask(
      inputs > 0, inputs > 0, dtype=cfg.dtype)
  # Add segmentation block-diagonal attention mask if using segmented data.
  if inputs_segmentation is not None:
    encoder_mask = nn.combine_masks(
        encoder_mask,
        nn.make_attention_mask(
            inputs_segmentation, inputs_segmentation, jnp.equal,
            dtype=cfg.dtype))
  return self.encoder(
      inputs, inputs_positions=inputs_positions, encoder_mask=encoder_mask)

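# Illustrative sketch (not part of the model code above): how the
# block-diagonal segmentation mask for packed examples is built with
# flax.linen.make_attention_mask and jnp.equal. The toy arrays here are made
# up for illustration; in the real models this mask is further combined with
# the padding mask, so padding positions (segment id 0) do not attend freely.
import jax.numpy as jnp
from flax import linen as nn

# One packed row holding two examples: three tokens of segment 1, two of
# segment 2, and one padding position.
segmentation = jnp.array([[1, 1, 1, 2, 2, 0]])
seg_mask = nn.make_attention_mask(segmentation, segmentation, jnp.equal)
# seg_mask has shape [1, 1, 6, 6]; entry [0, 0, i, j] is 1.0 only when tokens
# i and j carry the same segment id, giving a block-diagonal pattern.
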
def decode(self, programs, encoded, encoded_padding_mask):
  """Applies decoder on programs and encoded specification."""
  cfg = self.config

  # Allow for decoding without num_partial dimension for beam search.
  # programs shape == [batch_size, (num_partial), length]
  assert programs.ndim in [2, 3], ('Number of program dimensions should be'
                                   ' 2 or 3, but it is: %d' % programs.ndim)
  assert encoded.ndim == programs.ndim + 2

  # Collapse num_io dimension.
  num_io_axis = 1 if programs.ndim == 2 else 2
  flat_encoded = base_models.flatten_num_io_dim(encoded, axis=num_io_axis)
  flat_encoded_padding_mask = base_models.flatten_num_io_dim(
      encoded_padding_mask, axis=num_io_axis)

  # Make attention masks.
  if cfg.decode:
    # For fast decode with caching, programs shape == [batch_size, 1] and
    # cfg.shift = False, cfg.decode = True.
    decoder_mask = None
    encoder_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(programs), flat_encoded_padding_mask, dtype=cfg.dtype)
  else:
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(programs > 0, programs > 0, dtype=cfg.dtype),
        nn.make_causal_mask(programs, dtype=cfg.dtype))
    encoder_decoder_mask = nn.make_attention_mask(
        programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)

  return self.decoder(
      programs, flat_encoded, decoder_mask, encoder_decoder_mask)

def _concatenate_to_cache(self, key, value, query, attention_mask):
    """
    This function takes projected key, value states from a single input token and concatenates the states to cached
    states from previous steps. This function is slightly adapted from the official Flax repository:
    https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
    """
    # detect if we're initializing by absence of existing cache data.
    is_initialized = self.has_variable("cache", "cached_key")
    cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
    cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
    cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

    if is_initialized:
        *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
        # update key, value caches with our new 1d spatial slices
        cur_index = cache_index.value
        indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
        key = lax.dynamic_update_slice(cached_key.value, key, indices)
        value = lax.dynamic_update_slice(cached_value.value, value, indices)
        cached_key.value = key
        cached_value.value = value
        num_updated_cache_vectors = query.shape[1]
        cache_index.value = cache_index.value + num_updated_cache_vectors
        # causal mask for cached decoder self-attention: our single query position should only attend to those
        # key positions that have already been generated and cached, not the remaining zero elements.
        pad_mask = jnp.broadcast_to(
            jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
            tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
        )
        attention_mask = combine_masks(pad_mask, attention_mask)
    return key, value, attention_mask

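# Illustrative sketch (not from the original file): the cache update above in
# isolation. A fixed-length key cache is filled one decoding step at a time
# with lax.dynamic_update_slice; the shapes and the step index are made up
# for illustration.
import jax.numpy as jnp
from jax import lax

batch, max_length, num_heads, head_dim = 1, 8, 2, 4
cached_key = jnp.zeros((batch, max_length, num_heads, head_dim))
new_key = jnp.ones((batch, 1, num_heads, head_dim))  # projection of one new token
cur_index = 3  # decoding step 3
# Write the new slice at position `cur_index` along the length axis.
cached_key = lax.dynamic_update_slice(cached_key, new_key, (0, cur_index, 0, 0))
# Only positions <= cur_index hold real keys; that is what the pad_mask above
# enforces for the single query position being decoded.
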
def decode(self,
           encoded,
           inputs,
           targets,
           targets_positions=None,
           inputs_segmentation=None,
           targets_segmentation=None,
           train=False):
  # Make padding attention masks.
  dtype = jnp.bfloat16 if self.use_bfloat16 else jnp.float32
  if self.should_decode:
    # For fast autoregressive decoding, only a special encoder-decoder mask is
    # used.
    decoder_mask = None
    encoder_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(targets) > 0, inputs > 0, dtype=dtype)
  else:
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(targets > 0, targets > 0, dtype=dtype),
        nn.make_causal_mask(targets, dtype=dtype))
    encoder_decoder_mask = nn.make_attention_mask(
        targets > 0, inputs > 0, dtype=dtype)

  # Add segmentation block-diagonal attention masks if using segmented data.
  if inputs_segmentation is not None:
    decoder_mask = nn.combine_masks(
        decoder_mask,
        nn.make_attention_mask(
            targets_segmentation, targets_segmentation, jnp.equal,
            dtype=dtype))
    encoder_decoder_mask = nn.combine_masks(
        encoder_decoder_mask,
        nn.make_attention_mask(
            targets_segmentation, inputs_segmentation, jnp.equal,
            dtype=dtype))

  logits = self.decoder(
      encoded,
      targets,
      targets_positions=targets_positions,
      decoder_mask=decoder_mask,
      encoder_decoder_mask=encoder_decoder_mask,
      train=train)
  return logits

def __call__(
    self,
    hidden_states,
    attention_mask=None,
    deterministic: bool = True,
    output_attentions: bool = False,
):
    query = self.q_proj(hidden_states)
    key = self.k_proj(hidden_states)
    value = self.v_proj(hidden_states)

    query = self._split_heads(query)
    key = self._split_heads(key)
    value = self._split_heads(value)

    causal_attention_mask = None
    if self.causal:
        query_length, key_length = query.shape[1], key.shape[1]
        causal_attention_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length]

    if attention_mask is not None and causal_attention_mask is not None:
        attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
        attention_mask = combine_masks(attention_mask, causal_attention_mask, dtype="i4")
    elif causal_attention_mask is not None:
        attention_mask = causal_attention_mask
    elif attention_mask is not None:
        attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

    if attention_mask is not None:
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
        )
    else:
        attention_bias = None

    dropout_rng = None
    if not deterministic and self.dropout > 0.0:
        dropout_rng = self.make_rng("dropout")

    attn_weights = dot_product_attention_weights(
        query,
        key,
        bias=attention_bias,
        dropout_rng=dropout_rng,
        dropout_rate=self.dropout,
        deterministic=deterministic,
        dtype=self.dtype,
        precision=None,
    )

    attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
    attn_output = self._merge_heads(attn_output)
    attn_output = self.out_proj(attn_output)

    outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
    return outputs

def __call__(self, inputs, inputs_positions=None, inputs_segmentation=None):
  """Applies TransformerLM on the inputs.

  Args:
    inputs: target data.
    inputs_positions: input subsequence positions for packed examples.
    inputs_segmentation: input segmentation info for packed examples.

  Returns:
    logits array from transformer decoder.
  """
  config = self.config

  # Make padding attention masks.
  if config.decode:
    # For fast autoregressive decoding we use no decoder mask.
    decoder_mask = None
  else:
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(inputs > 0, inputs > 0, dtype=config.dtype),
        nn.make_causal_mask(inputs, dtype=config.dtype))

  # Add segmentation block-diagonal attention masks if using segmented data.
  if inputs_segmentation is not None:
    decoder_mask = nn.combine_masks(
        decoder_mask,
        nn.make_attention_mask(
            inputs_segmentation, inputs_segmentation, jnp.equal,
            dtype=config.dtype))

  logits = Decoder(config=config, shared_embedding=None, name='decoder')(
      inputs,
      inputs_positions=inputs_positions,
      inputs_segmentation=inputs_segmentation,
      decoder_mask=decoder_mask,
      encoder_decoder_mask=None)
  return logits.astype(self.config.dtype)

def decode(self, programs, latents, encoded, latents_padding_mask,
           encoded_padding_mask):
  """Applies decoder on programs and encoded specification."""
  cfg = self.config

  assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                              ' but it is: %d' % programs.ndim)
  assert latents.ndim == 3, ('Number of latents dimensions should be 3,'
                             ' but it is: %d' % latents.ndim)
  assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                             ' but it is: %d' % encoded.ndim)

  # Collapse num_io dimension.
  flat_encoded = models.flatten_num_io_dim(encoded)
  flat_encoded_padding_mask = models.flatten_num_io_dim(encoded_padding_mask)

  latents = self.latent_pos_emb(latents)
  # Concatenate the i/o encoding and latents together.
  flat_encoded = jnp.concatenate([flat_encoded, latents], axis=1)

  # Make attention masks.
  if cfg.decode:
    # For fast decode with caching, programs shape == [batch_size, 1] and
    # cfg.shift = False, cfg.decode = True.
    decoder_mask = None
    latent_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(programs), latents_padding_mask, dtype=cfg.dtype)
    encoder_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(programs), flat_encoded_padding_mask, dtype=cfg.dtype)
    encoder_decoder_mask = jnp.concatenate(
        [encoder_decoder_mask, latent_decoder_mask], axis=-1)
  else:
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(programs > 0, programs > 0, dtype=cfg.dtype),
        nn.make_causal_mask(programs, dtype=cfg.dtype))
    latent_decoder_mask = nn.make_attention_mask(
        programs > 0, latents_padding_mask, dtype=cfg.dtype)
    encoder_decoder_mask = nn.make_attention_mask(
        programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)
    encoder_decoder_mask = jnp.concatenate(
        [encoder_decoder_mask, latent_decoder_mask], axis=-1)

  return self.decoder(
      programs, flat_encoded, decoder_mask, encoder_decoder_mask)

def get_attention_masks(self, inputs, targets):
  cfg = self.config
  if cfg.decode:
    decoder_mask = None
    encoder_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(targets) > 0, inputs > 0)
  else:
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(targets > 0, targets > 0, dtype=cfg.dtype),
        nn.make_causal_mask(targets, dtype=cfg.dtype))
    encoder_decoder_mask = nn.make_attention_mask(
        targets > 0, inputs > 0, dtype=cfg.dtype)
  return decoder_mask, encoder_decoder_mask

def encode(self,
           inputs,
           inputs_positions=None,
           inputs_segmentation=None,
           train=False):
  # Make padding attention mask.
  encoder_mask = nn.make_attention_mask(
      inputs > 0, inputs > 0, dtype=inputs.dtype)
  # Add segmentation block-diagonal attention mask if using segmented data.
  if inputs_segmentation is not None:
    encoder_mask = nn.combine_masks(
        encoder_mask,
        nn.make_attention_mask(
            inputs_segmentation, inputs_segmentation, jnp.equal,
            dtype=inputs_segmentation.dtype))
  encoded = self.encoder(
      inputs,
      inputs_positions=inputs_positions,
      encoder_mask=encoder_mask,
      train=train)
  return encoded

def __call__(
    self,
    hidden_states,
    attention_mask=None,
    deterministic: bool = True,
    init_cache: bool = False,
    output_attentions: bool = False,
):
    qkv_out = self.c_attn(hidden_states)
    query, key, value = jnp.split(qkv_out, 3, axis=2)

    query = self._split_heads(query)
    key = self._split_heads(key)
    value = self._split_heads(value)

    query_length, key_length = query.shape[1], key.shape[1]

    if self.has_variable("cache", "cached_key"):
        mask_shift = self.variables["cache"]["cache_index"]
        max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
        causal_mask = lax.dynamic_slice(
            self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
        )
    else:
        causal_mask = self.causal_mask[:, :, :query_length, :key_length]

    batch_size = hidden_states.shape[0]
    causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])

    attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
    attention_mask = combine_masks(attention_mask, causal_mask)

    dropout_rng = None
    if not deterministic and self.config.attn_pdrop > 0.0:
        dropout_rng = self.make_rng("dropout")

    # During fast autoregressive decoding, we feed one position at a time,
    # and cache the keys and values step by step.
    if self.has_variable("cache", "cached_key") or init_cache:
        key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)

    # transform boolean mask into float mask
    attention_bias = lax.select(
        attention_mask > 0,
        jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
        jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
    )

    # usual dot product attention
    attn_output = dot_product_attention(
        query,
        key,
        value,
        bias=attention_bias,
        dropout_rng=dropout_rng,
        dropout_rate=self.config.attn_pdrop,
        deterministic=deterministic,
        dtype=self.dtype,
        precision=None,
    )

    attn_output = self._merge_heads(attn_output)
    attn_output = self.c_proj(attn_output)
    attn_output = self.resid_dropout(attn_output, deterministic=deterministic)

    # TODO: at the moment it's not possible to retrieve attn_weights from
    # dot_product_attention, but should be in the future -> add functionality then
    return (attn_output,)

def __call__(
    self,
    hidden_states: jnp.ndarray,
    key_value_states: Optional[jnp.ndarray] = None,
    attention_mask: Optional[jnp.ndarray] = None,
    init_cache: bool = False,
    deterministic: bool = True,
) -> Tuple[jnp.ndarray]:
    """Input shape: Batch x Time x Channel"""

    # if key_value_states are provided this layer is used as a cross-attention layer
    # for the decoder
    is_cross_attention = key_value_states is not None
    batch_size = hidden_states.shape[0]

    # get query proj
    query_states = self.q_proj(hidden_states)
    # get key, value proj
    if is_cross_attention:
        # cross_attentions
        key_states = self.k_proj(key_value_states)
        value_states = self.v_proj(key_value_states)
    else:
        # self_attention
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

    query_states = self._split_heads(query_states)
    key_states = self._split_heads(key_states)
    value_states = self._split_heads(value_states)

    # handle cache, prepare causal attention mask
    if self.causal:
        query_length, key_length = query_states.shape[1], key_states.shape[1]
        if self.has_variable("cache", "cached_key"):
            mask_shift = self.variables["cache"]["cache_index"]
            max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
            causal_mask = lax.dynamic_slice(
                self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
            )
        else:
            causal_mask = self.causal_mask[:, :, :query_length, :key_length]
        causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])

    # combine masks if needed
    if attention_mask is not None and self.causal:
        attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
        attention_mask = combine_masks(attention_mask, causal_mask)
    elif self.causal:
        attention_mask = causal_mask
    elif attention_mask is not None:
        attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

    # During fast autoregressive decoding, we feed one position at a time,
    # and cache the keys and values step by step.
    if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
        key_states, value_states, attention_mask = self._concatenate_to_cache(
            key_states, value_states, query_states, attention_mask
        )

    # Convert the boolean attention mask to an attention bias.
    if attention_mask is not None:
        # attention mask in the form of attention bias
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
        )
    else:
        attention_bias = None

    dropout_rng = None
    if not deterministic and self.dropout > 0.0:
        dropout_rng = self.make_rng("dropout")

    attn_weights = dot_product_attention_weights(
        query_states,
        key_states,
        bias=attention_bias,
        dropout_rng=dropout_rng,
        dropout_rate=self.dropout,
        broadcast_dropout=True,
        deterministic=deterministic,
        dtype=self.dtype,
        precision=None,
    )

    attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
    attn_output = self._merge_heads(attn_output)
    attn_output = self.out_proj(attn_output)

    return attn_output, attn_weights

def __call__(
    self,
    hidden_states,
    key_value_states: Optional[jnp.ndarray] = None,
    attention_mask=None,
    deterministic: bool = True,
    init_cache: bool = False,
    output_attentions: bool = False,
):
    # if key_value_states are provided this layer is used as a cross-attention layer
    # for the decoder
    is_cross_attention = key_value_states is not None
    batch_size = hidden_states.shape[0]

    if not is_cross_attention:
        qkv_out = self.c_attn(hidden_states)
        query, key, value = jnp.split(qkv_out, 3, axis=2)
    else:
        q_out = self.q_attn(hidden_states)
        (query,) = jnp.split(q_out, 1, axis=2)
        kv_out = self.c_attn(key_value_states)
        key, value = jnp.split(kv_out, 2, axis=2)

    query = self._split_heads(query)
    key = self._split_heads(key)
    value = self._split_heads(value)

    query_length, key_length = query.shape[1], key.shape[1]

    if self.causal:
        if self.has_variable("cache", "cached_key"):
            mask_shift = self.variables["cache"]["cache_index"]
            max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
            causal_mask = lax.dynamic_slice(
                self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
            )
        else:
            causal_mask = self.causal_mask[:, :, :query_length, :key_length]
        causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])

    # combine masks if needed
    if attention_mask is not None and self.causal:
        attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
        attention_mask = combine_masks(attention_mask, causal_mask)
    elif self.causal:
        attention_mask = causal_mask
    elif attention_mask is not None:
        attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

    dropout_rng = None
    if not deterministic and self.config.attn_pdrop > 0.0:
        dropout_rng = self.make_rng("dropout")

    # During fast autoregressive decoding, we feed one position at a time,
    # and cache the keys and values step by step.
    if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
        key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)

    # transform boolean mask into float mask
    if attention_mask is not None:
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
        )
    else:
        attention_bias = None

    # usual dot product attention
    attn_weights = dot_product_attention_weights(
        query,
        key,
        bias=attention_bias,
        dropout_rng=dropout_rng,
        dropout_rate=self.config.attn_pdrop,
        deterministic=deterministic,
        dtype=self.dtype,
        precision=None,
    )

    attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
    attn_output = self._merge_heads(attn_output)
    attn_output = self.c_proj(attn_output)
    attn_output = self.resid_dropout(attn_output, deterministic=deterministic)

    outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
    return outputs

def __call__(self,
             inputs,
             train,
             inputs_positions=None,
             inputs_segmentation=None):
  """Applies Transformer model on the inputs.

  Args:
    inputs: input data.
    train: bool: if model is training.
    inputs_positions: input subsequence positions for packed examples.
    inputs_segmentation: input segmentation info for packed examples.

  Returns:
    output of a transformer decoder.
  """
  assert inputs.ndim == 2  # (batch, len)
  dtype = utils.dtype_from_str(self.model_dtype)

  if self.decode:
    # For fast autoregressive decoding we use no decoder mask.
    decoder_mask = None
  else:
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(inputs > 0, inputs > 0, dtype=dtype),
        nn.make_causal_mask(inputs, dtype=dtype))
  if inputs_segmentation is not None:
    decoder_mask = nn.combine_masks(
        decoder_mask,
        nn.make_attention_mask(
            inputs_segmentation, inputs_segmentation, jnp.equal, dtype=dtype))

  y = inputs.astype('int32')
  if not self.decode:
    y = shift_inputs(y, segment_ids=inputs_segmentation)

  # TODO(gdahl,znado): this code appears to be accessing out-of-bounds
  # indices for dataset_lib:proteins_test. This will break when jnp.take() is
  # updated to return NaNs for out-of-bounds indices.
  # Debug why this is the case.
  y = jnp.clip(y, 0, self.vocab_size - 1)

  if self.shared_embedding is None:
    output_embed = nn.Embed(
        num_embeddings=self.vocab_size,
        features=self.emb_dim,
        embedding_init=nn.initializers.normal(stddev=1.0))
  else:
    output_embed = self.shared_embedding

  y = output_embed(y)
  y = AddPositionEmbs(
      max_len=self.max_len,
      posemb_init=sinusoidal_init(max_len=self.max_len),
      decode=self.decode,
      name='posembed_output')(
          y, inputs_positions=inputs_positions)
  y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=not train)

  y = y.astype(dtype)

  for _ in range(self.num_layers):
    y = Transformer1DBlock(
        qkv_dim=self.qkv_dim,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        dropout_rate=self.dropout_rate,
        attention_dropout_rate=self.attention_dropout_rate,
        attention_fn=self.attention_fn,
        normalizer=self.normalizer,
        dtype=dtype)(
            inputs=y,
            train=train,
            decoder_mask=decoder_mask,
            encoder_decoder_mask=None,
            inputs_positions=None,
            inputs_segmentation=None)
  if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']:
    maybe_normalize = model_utils.get_normalizer(
        self.normalizer, train, dtype=dtype)
    y = maybe_normalize()(y)

  if self.logits_via_embedding:
    # Use the transpose of embedding matrix for logit transform.
    logits = output_embed.attend(y.astype(jnp.float32))
    # Correctly normalize pre-softmax logits for this shared case.
    logits = logits / jnp.sqrt(y.shape[-1])
  else:
    logits = nn.Dense(
        self.vocab_size,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6),
        dtype=dtype,
        name='logits_dense')(y)
  return logits.astype(dtype)

def __call__(
    self,
    hidden_states,
    attention_mask,
    position_ids,
    deterministic: bool = True,
    init_cache: bool = False,
    output_attentions: bool = False,
):
    query = self.q_proj(hidden_states)
    key = self.k_proj(hidden_states)
    value = self.v_proj(hidden_states)

    query = self._split_heads(query)
    key = self._split_heads(key)
    value = self._split_heads(value)

    sincos = jnp.take(self.embed_positions, position_ids, axis=0)
    sincos = jnp.split(sincos, 2, axis=-1)
    if self.rotary_dim is not None:
        k_rot = key[:, :, :, : self.rotary_dim]
        k_pass = key[:, :, :, self.rotary_dim :]

        q_rot = query[:, :, :, : self.rotary_dim]
        q_pass = query[:, :, :, self.rotary_dim :]

        k_rot = apply_rotary_pos_emb(k_rot, sincos)
        q_rot = apply_rotary_pos_emb(q_rot, sincos)

        key = jnp.concatenate([k_rot, k_pass], axis=-1)
        query = jnp.concatenate([q_rot, q_pass], axis=-1)
    else:
        key = apply_rotary_pos_emb(key, sincos)
        query = apply_rotary_pos_emb(query, sincos)

    query_length, key_length = query.shape[1], key.shape[1]

    if self.has_variable("cache", "cached_key"):
        mask_shift = self.variables["cache"]["cache_index"]
        max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
        causal_mask = lax.dynamic_slice(
            self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
        )
    else:
        causal_mask = self.causal_mask[:, :, :query_length, :key_length]

    batch_size = hidden_states.shape[0]
    causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])

    attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
    attention_mask = combine_masks(attention_mask, causal_mask)

    dropout_rng = None
    if not deterministic and self.config.attn_pdrop > 0.0:
        dropout_rng = self.make_rng("dropout")

    # During fast autoregressive decoding, we feed one position at a time,
    # and cache the keys and values step by step.
    if self.has_variable("cache", "cached_key") or init_cache:
        key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)

    # transform boolean mask into float mask
    attention_bias = lax.select(
        attention_mask > 0,
        jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
        jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
    )

    # usual dot product attention
    attn_weights = dot_product_attention_weights(
        query,
        key,
        bias=attention_bias,
        dropout_rng=dropout_rng,
        dropout_rate=self.config.attn_pdrop,
        deterministic=deterministic,
        dtype=self.dtype,
        precision=None,
    )

    attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
    attn_output = self._merge_heads(attn_output)
    attn_output = self.out_proj(attn_output)
    attn_output = self.resid_dropout(attn_output, deterministic=deterministic)

    outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
    return outputs

def decode(self, programs, encoded, encoded_padding_mask):
  """Applies decoder on programs and encoded specification."""
  cfg = self.config.base_config

  assert programs.ndim == 2, ('Number of program dimensions should be 2,'
                              ' but it is: %d' % programs.ndim)
  assert encoded.ndim == 4, ('Number of encoded dimensions should be 4,'
                             ' but it is: %d' % encoded.ndim)

  # Collapse num_io dimension.
  flat_encoded = base_models.flatten_num_io_dim(encoded)
  flat_encoded_padding_mask = base_models.flatten_num_io_dim(
      encoded_padding_mask)

  if cfg.shift:
    programs = base_models.shift_right(programs, cfg.bos_token)

  # Make attention masks.
  decoder_mask = None
  decoder_relative_position = None  # Relative positions.
  if cfg.decode:
    # For fast decode with caching, programs shape == [batch_size, 1] and
    # cfg.shift = False, cfg.decode = True.
    # TODO(jxihong): Fast decoding currently does not work with new attention.
    encoder_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(programs), flat_encoded_padding_mask, dtype=cfg.dtype)
  else:
    attention_mask_type = self.config.attention_mask_type
    if attention_mask_type == 'baseline':
      decoder_mask = nn.combine_masks(
          nn.make_attention_mask(programs > 0, programs > 0, dtype=cfg.dtype),
          nn.make_causal_mask(programs, dtype=cfg.dtype))
    else:
      if attention_mask_type == 'bos_to_bos':
        # BOS tokens attend to all previous BOS tokens.
        decoder_bos_mask = nn.combine_masks(
            nn.make_attention_mask(
                programs == cfg.bos_token, programs == cfg.bos_token,
                dtype=cfg.dtype),
            nn.make_causal_mask(programs, dtype=cfg.dtype))
      elif attention_mask_type == 'bos_to_last':
        # BOS tokens attend to all last partial program tokens.
        bos_mask = nn.combine_masks(
            nn.make_attention_mask(
                programs == cfg.bos_token, programs == cfg.bos_token,
                dtype=cfg.dtype),
            nn.make_causal_mask(programs, dtype=cfg.dtype))
        # Shift bos mask to left to get all previous last partial program
        # tokens.
        decoder_bos_mask = shift_left(bos_mask)
      elif attention_mask_type == 'bos_to_bos_and_last':
        # BOS tokens attend to all previous BOS + last partial program tokens.
        bos_mask = nn.combine_masks(
            nn.make_attention_mask(
                programs == cfg.bos_token, programs == cfg.bos_token,
                dtype=cfg.dtype),
            nn.make_causal_mask(programs, dtype=cfg.dtype))
        # Shift bos mask to left to get all previous last partial program
        # tokens.
        decoder_bos_mask = jnp.logical_or(bos_mask, shift_left(bos_mask))
      elif attention_mask_type == 'bos_full_attention':
        # BOS tokens attend to all previous tokens, including program tokens.
        decoder_bos_mask = nn.combine_masks(
            nn.make_attention_mask(
                programs == cfg.bos_token, programs > 0, dtype=cfg.dtype),
            nn.make_causal_mask(programs, dtype=cfg.dtype))
      else:
        raise ValueError(
            'Unhandled attention_mask_type: {}'.format(attention_mask_type))
      # Program tokens attend to all previous tokens in partial program.
      decoder_partial_mask = nn.combine_masks(
          make_partial_program_mask(
              programs, bos_token=cfg.bos_token, dtype=cfg.dtype),
          nn.make_causal_mask(programs, dtype=cfg.dtype))
      decoder_mask = nn.combine_masks(
          nn.make_attention_mask(programs > 0, programs > 0, dtype=cfg.dtype),
          jnp.logical_or(decoder_bos_mask, decoder_partial_mask))

    if self.config.bos_special_attention:
      # Make custom relative positions where BOS are separately indexed.
      decoder_relative_position = make_relative_position(programs)
      decoder_partial_relative_position = (
          make_partial_program_relative_position(
              programs, bos_token=cfg.bos_token))
      decoder_relative_position = jnp.where(
          (programs == cfg.bos_token)[Ellipsis, None],
          decoder_partial_relative_position,
          decoder_relative_position)
    else:
      decoder_relative_position = None

    encoder_decoder_mask = nn.make_attention_mask(
        programs > 0, flat_encoded_padding_mask, dtype=cfg.dtype)

  return self.decoder(programs, flat_encoded, decoder_mask,
                      encoder_decoder_mask, decoder_relative_position)