def __call__(
    self,
    hidden_states,
    attention_mask,
    layer_head_mask,
    deterministic=True,
    output_attentions: bool = False,
):
    """Run multi-head self-attention over `hidden_states`.

    Args:
        hidden_states: Input array. Its first two axes are preserved; the
            projected features are reshaped to
            `(num_attention_heads, head_dim)`.
        attention_mask: Mask whose entries `> 0` mark positions that may be
            attended to. Two singleton axes are inserted at positions -3 and
            -2 before it is turned into an additive bias. `None` disables
            masking.
        layer_head_mask: Optional vector with one multiplier per head, applied
            to the attention weights. `None` leaves the weights untouched.
        deterministic: When false and the configured
            `attention_probs_dropout_prob` is positive, a "dropout" RNG is
            requested and passed to the attention-weight computation.
        output_attentions: When true, the attention weights are returned as
            the second tuple element.

    Returns:
        `(attn_output,)`, or `(attn_output, attn_weights)` when
        `output_attentions` is true.
    """
    num_heads = self.config.num_attention_heads
    head_dim = self.config.hidden_size // num_heads
    split_shape = hidden_states.shape[:2] + (num_heads, head_dim)

    query_states = self.query(hidden_states).reshape(split_shape)
    value_states = self.value(hidden_states).reshape(split_shape)
    key_states = self.key(hidden_states).reshape(split_shape)

    # Convert the boolean attention mask to an attention bias.
    if attention_mask is not None:
        attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
        # Use the most negative *finite* value of the compute dtype. A fixed
        # constant such as -1e10 overflows to -inf once cast to float16, and a
        # row whose biases are all -inf turns into NaN after the softmax.
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
        )
    else:
        attention_bias = None

    dropout_rng = None
    if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
        dropout_rng = self.make_rng("dropout")

    attn_weights = dot_product_attention_weights(
        query_states,
        key_states,
        bias=attention_bias,
        dropout_rng=dropout_rng,
        dropout_rate=self.config.attention_probs_dropout_prob,
        broadcast_dropout=True,
        deterministic=deterministic,
        dtype=self.dtype,
        precision=None,
    )

    # Mask heads if we want to: scale every head's weights by its entry.
    if layer_head_mask is not None:
        attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask)

    attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
    # Merge the head and head_dim axes back into a single feature axis.
    attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))

    outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
    return outputs
def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False):
    """Unmasked multi-head self-attention.

    Args:
        hidden_states: Input array; its first two axes are preserved while
            the projected features are split into
            `(num_attention_heads, head_dim)`.
        deterministic: When false and the configured
            `attention_probs_dropout_prob` is positive, a "dropout" RNG is
            requested for the attention weights.
        output_attentions: When true, the attention weights are returned too.

    Returns:
        `(attn_output,)`, or `(attn_output, attn_weights)` when
        `output_attentions` is true.
    """
    heads = self.config.num_attention_heads
    per_head = self.config.hidden_size // heads
    per_head_shape = hidden_states.shape[:2] + (heads, per_head)

    query_states = self.query(hidden_states).reshape(per_head_shape)
    value_states = self.value(hidden_states).reshape(per_head_shape)
    key_states = self.key(hidden_states).reshape(per_head_shape)

    # A dropout RNG is only drawn when dropout can actually take effect.
    wants_dropout = (not deterministic) and self.config.attention_probs_dropout_prob > 0.0
    dropout_rng = self.make_rng("dropout") if wants_dropout else None

    attn_weights = dot_product_attention_weights(
        query_states,
        key_states,
        dropout_rng=dropout_rng,
        dropout_rate=self.config.attention_probs_dropout_prob,
        broadcast_dropout=True,
        deterministic=deterministic,
        dtype=self.dtype,
        precision=None,
    )

    context = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
    # Collapse the trailing (heads, head_dim) axes into one feature axis.
    context = context.reshape(context.shape[:2] + (-1,))

    if output_attentions:
        return (context, attn_weights)
    return (context,)
def __call__(
    self,
    hidden_states,
    attention_mask=None,
    deterministic: bool = True,
    output_attentions: bool = False,
):
    """Multi-head attention with an optional padding mask and causal mask.

    Args:
        hidden_states: Input array fed to the query, key and value
            projections.
        attention_mask: Optional mask; entries `> 0` may be attended to. Two
            singleton axes are inserted at positions -3 and -2 before use.
        deterministic: When false and `self.dropout` is positive, a "dropout"
            RNG is requested for the attention weights.
        output_attentions: When true, the attention weights are returned too.

    Returns:
        `(attn_output,)`, or `(attn_output, attn_weights)` when
        `output_attentions` is true.
    """
    query = self._split_heads(self.q_proj(hidden_states))
    key = self._split_heads(self.k_proj(hidden_states))
    value = self._split_heads(self.v_proj(hidden_states))

    if attention_mask is not None:
        attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

    if self.causal:
        q_len, k_len = query.shape[1], key.shape[1]
        # Rows for the last `q_len` query positions, columns for every key.
        causal_slice = self.causal_mask[:, :, k_len - q_len : k_len, :k_len]
        if attention_mask is None:
            attention_mask = causal_slice
        else:
            attention_mask = combine_masks(attention_mask, causal_slice, dtype="i4")

    # Turn the mask into an additive bias: 0 where allowed, -1e4 elsewhere.
    attention_bias = None
    if attention_mask is not None:
        allowed = jnp.full(attention_mask.shape, 0.0).astype(self.dtype)
        blocked = jnp.full(attention_mask.shape, -1e4).astype(self.dtype)
        attention_bias = lax.select(attention_mask > 0, allowed, blocked)

    wants_dropout = (not deterministic) and self.dropout > 0.0
    dropout_rng = self.make_rng("dropout") if wants_dropout else None

    attn_weights = dot_product_attention_weights(
        query,
        key,
        bias=attention_bias,
        dropout_rng=dropout_rng,
        dropout_rate=self.dropout,
        deterministic=deterministic,
        dtype=self.dtype,
        precision=None,
    )

    context = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
    context = self.out_proj(self._merge_heads(context))

    if output_attentions:
        return (context, attn_weights)
    return (context,)
def __call__(
    self,
    hidden_states: jnp.ndarray,
    key_value_states: Optional[jnp.ndarray] = None,
    attention_mask: Optional[jnp.ndarray] = None,
    deterministic: bool = True,
) -> Tuple[jnp.ndarray]:
    """Input shape: Batch x Time x Channel

    Args:
        hidden_states: Input array; always the source of the queries.
        key_value_states: Optional array from which keys and values are
            projected. When `None` (the default) keys and values are
            projected from `hidden_states`, i.e. plain self-attention.
        attention_mask: Optional mask; entries `> 0` may be attended to. Two
            singleton axes are inserted at positions -3 and -2 before use.
        deterministic: When false and `self.dropout` is positive, a "dropout"
            RNG is requested for the attention weights.

    Returns:
        The tuple `(attn_output, attn_weights)`.
    """
    # Keys/values come from `key_value_states` when it is supplied, so the
    # argument is honoured instead of being silently ignored.
    kv_input = hidden_states if key_value_states is None else key_value_states

    # get query proj
    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(kv_input)
    value_states = self.v_proj(kv_input)

    query_states = self._split_heads(query_states)
    key_states = self._split_heads(key_states)
    value_states = self._split_heads(value_states)

    # Convert the boolean attention mask to an attention bias.
    if attention_mask is not None:
        attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
        # Use the most negative *finite* value of the compute dtype rather
        # than -inf: a row whose biases are all -inf becomes NaN after the
        # softmax, whereas a finite minimum keeps the result well defined.
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
        )
    else:
        attention_bias = None

    dropout_rng = None
    if not deterministic and self.dropout > 0.0:
        dropout_rng = self.make_rng("dropout")

    attn_weights = dot_product_attention_weights(
        query_states,
        key_states,
        bias=attention_bias,
        dropout_rng=dropout_rng,
        dropout_rate=self.dropout,
        broadcast_dropout=True,
        deterministic=deterministic,
        dtype=self.dtype,
        precision=None,
    )

    attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
    attn_output = self._merge_heads(attn_output)
    attn_output = self.out_proj(attn_output)

    return attn_output, attn_weights
def __call__(
    self,
    hidden_states,
    relative_position_bias=None,
    deterministic: bool = True,
    output_attentions: bool = False,
):
    """Multi-head self-attention with optional relative position biases.

    Args:
        hidden_states: Input array; its first two axes are preserved while
            the projected features are split into
            `(num_attention_heads, head_dim)`.
        relative_position_bias: Optional shared bias added on top of this
            module's own bias (or on top of zero when the module has none).
        deterministic: When false and the configured
            `attention_probs_dropout_prob` is positive, a "dropout" RNG is
            requested for the attention weights.
        output_attentions: When true, the attention weights are returned too.

    Returns:
        `(attn_output,)`, or `(attn_output, attn_weights)` when
        `output_attentions` is true.
    """
    heads = self.config.num_attention_heads
    per_head = self.config.hidden_size // heads
    per_head_shape = hidden_states.shape[:2] + (heads, per_head)

    query_states = self.query(hidden_states).reshape(per_head_shape)
    value_states = self.value(hidden_states).reshape(per_head_shape)
    key_states = self.key(hidden_states).reshape(per_head_shape)

    wants_dropout = (not deterministic) and self.config.attention_probs_dropout_prob > 0.0
    dropout_rng = self.make_rng("dropout") if wants_dropout else None

    # Start from the module's own relative position bias when it has one,
    # otherwise from a scalar zero in the compute dtype.
    if self.relative_position_bias is None:
        attention_bias = jnp.array(0.0, dtype=self.dtype)
    else:
        own_bias = jnp.expand_dims(self.relative_position_bias(), 0)
        attention_bias = own_bias.astype(query_states.dtype)

    # Add the shared relative position bias if one was provided.
    if relative_position_bias is not None:
        shared_bias = relative_position_bias.astype(attention_bias.dtype)
        attention_bias = attention_bias + shared_bias

    attn_weights = dot_product_attention_weights(
        query_states,
        key_states,
        bias=attention_bias,
        dropout_rng=dropout_rng,
        dropout_rate=self.config.attention_probs_dropout_prob,
        broadcast_dropout=True,
        deterministic=deterministic,
        dtype=self.dtype,
        precision=None,
    )

    context = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
    # Collapse the trailing (heads, head_dim) axes into one feature axis.
    context = context.reshape(context.shape[:2] + (-1,))

    if output_attentions:
        return (context, attn_weights)
    return (context,)
def __call__(
    self,
    hidden_states,
    attention_mask,
    position_ids,
    deterministic: bool = True,
    init_cache: bool = False,
    output_attentions: bool = False,
):
    """Causal multi-head self-attention with rotary position embeddings.

    Args:
        hidden_states: Input array fed to the query, key and value
            projections; axis 0 is used as the batch size.
        attention_mask: Mask whose entries `> 0` may be attended to. It is
            expanded with singleton axes at positions -3 and -2, broadcast to
            the causal mask's shape and combined with it.
        position_ids: Indices into `self.embed_positions` (axis 0) selecting
            the sin/cos values for each position.
        deterministic: When false and `config.attn_pdrop` is positive, a
            "dropout" RNG is requested for the attention weights. Also
            forwarded to the residual dropout.
        init_cache: When true, the key/value cache path is taken even if no
            "cached_key" variable exists yet.
        output_attentions: When true, the attention weights are returned too.

    Returns:
        `(attn_output,)`, or `(attn_output, attn_weights)` when
        `output_attentions` is true.
    """
    query = self.q_proj(hidden_states)
    key = self.k_proj(hidden_states)
    value = self.v_proj(hidden_states)

    query = self._split_heads(query)
    key = self._split_heads(key)
    value = self._split_heads(value)

    # Look up the per-position table rows and split them into two halves
    # along the last axis (consumed together by `apply_rotary_pos_emb`).
    sincos = jnp.take(self.embed_positions, position_ids, axis=0)
    sincos = jnp.split(sincos, 2, axis=-1)
    if self.rotary_dim is not None:
        # Rotate only the first `rotary_dim` features; pass the rest through.
        k_rot = key[:, :, :, : self.rotary_dim]
        k_pass = key[:, :, :, self.rotary_dim :]

        q_rot = query[:, :, :, : self.rotary_dim]
        q_pass = query[:, :, :, self.rotary_dim :]

        k_rot = apply_rotary_pos_emb(k_rot, sincos)
        q_rot = apply_rotary_pos_emb(q_rot, sincos)

        key = jnp.concatenate([k_rot, k_pass], axis=-1)
        query = jnp.concatenate([q_rot, q_pass], axis=-1)
    else:
        key = apply_rotary_pos_emb(key, sincos)
        query = apply_rotary_pos_emb(query, sincos)

    query_length, key_length = query.shape[1], key.shape[1]

    if self.has_variable("cache", "cached_key"):
        # With a cache present, take the causal-mask rows starting at the
        # current cache index and spanning the full cached key length.
        mask_shift = self.variables["cache"]["cache_index"]
        max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
        causal_mask = lax.dynamic_slice(
            self.causal_mask,
            (0, 0, mask_shift, 0),
            (1, 1, query_length, max_decoder_length),
        )
    else:
        causal_mask = self.causal_mask[:, :, :query_length, :key_length]

    batch_size = hidden_states.shape[0]
    causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])

    attention_mask = jnp.broadcast_to(
        jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape
    )
    attention_mask = combine_masks(attention_mask, causal_mask)

    dropout_rng = None
    if not deterministic and self.config.attn_pdrop > 0.0:
        dropout_rng = self.make_rng("dropout")

    # During fast autoregressive decoding, we feed one position at a time,
    # and cache the keys and values step by step.
    if self.has_variable("cache", "cached_key") or init_cache:
        key, value, attention_mask = self._concatenate_to_cache(
            key, value, query, attention_mask
        )

    # Transform the boolean mask into a float bias. Use the most negative
    # *finite* value of the compute dtype: a fixed constant such as -1e9
    # overflows to -inf once cast to float16, and a query row that is fully
    # masked (e.g. a padded position under the causal mask) would then turn
    # into NaN after the softmax.
    attention_bias = lax.select(
        attention_mask > 0,
        jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
        jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
    )

    # usual dot product attention
    attn_weights = dot_product_attention_weights(
        query,
        key,
        bias=attention_bias,
        dropout_rng=dropout_rng,
        dropout_rate=self.config.attn_pdrop,
        deterministic=deterministic,
        dtype=self.dtype,
        precision=None,
    )

    attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
    attn_output = self._merge_heads(attn_output)
    attn_output = self.out_proj(attn_output)
    attn_output = self.resid_dropout(attn_output, deterministic=deterministic)

    outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
    return outputs
def __call__(
    self,
    hidden_states: jnp.ndarray,
    key_value_states: Optional[jnp.ndarray] = None,
    attention_mask: Optional[jnp.ndarray] = None,
    init_cache: bool = False,
    deterministic: bool = True,
) -> Tuple[jnp.ndarray]:
    """Input shape: Batch x Time x Channel

    Args:
        hidden_states: Input array; always the source of the queries.
        key_value_states: When provided, keys and values are projected from
            it, so the layer acts as a cross-attention layer for the decoder.
            Otherwise they are projected from `hidden_states`.
        attention_mask: Optional mask; entries `> 0` may be attended to.
        init_cache: When true (and the layer is causal), the key/value cache
            path is taken even if no "cached_key" variable exists yet.
        deterministic: When false and `self.dropout` is positive, a "dropout"
            RNG is requested for the attention weights.

    Returns:
        The tuple `(attn_output, attn_weights)`.
    """
    kv_source = hidden_states if key_value_states is None else key_value_states
    batch_size = hidden_states.shape[0]

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(kv_source)
    value_states = self.v_proj(kv_source)

    query_states = self._split_heads(query_states)
    key_states = self._split_heads(key_states)
    value_states = self._split_heads(value_states)

    if self.causal:
        has_cache = self.has_variable("cache", "cached_key")
        q_len, k_len = query_states.shape[1], key_states.shape[1]

        if has_cache:
            # Slice the causal-mask rows starting at the current cache index,
            # spanning the full length of the cached keys.
            cache_vars = self.variables["cache"]
            start_row = cache_vars["cache_index"]
            cached_len = cache_vars["cached_key"].shape[1]
            causal_mask = lax.dynamic_slice(
                self.causal_mask, (0, 0, start_row, 0), (1, 1, q_len, cached_len)
            )
        else:
            causal_mask = self.causal_mask[:, :, :q_len, :k_len]
        causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])

        # Merge the caller's mask (if any) with the causal mask.
        if attention_mask is None:
            attention_mask = causal_mask
        else:
            expanded = jnp.expand_dims(attention_mask, axis=(-3, -2))
            expanded = jnp.broadcast_to(expanded, causal_mask.shape)
            attention_mask = combine_masks(expanded, causal_mask)

        # During fast autoregressive decoding, one position is fed at a time
        # and keys/values are cached step by step.
        if has_cache or init_cache:
            key_states, value_states, attention_mask = self._concatenate_to_cache(
                key_states, value_states, query_states, attention_mask
            )
    elif attention_mask is not None:
        attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

    # Additive bias: 0 where attention is allowed, the dtype's most negative
    # finite value elsewhere.
    attention_bias = None
    if attention_mask is not None:
        allowed = jnp.full(attention_mask.shape, 0.0).astype(self.dtype)
        blocked = jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype)
        attention_bias = lax.select(attention_mask > 0, allowed, blocked)

    wants_dropout = (not deterministic) and self.dropout > 0.0
    dropout_rng = self.make_rng("dropout") if wants_dropout else None

    attn_weights = dot_product_attention_weights(
        query_states,
        key_states,
        bias=attention_bias,
        dropout_rng=dropout_rng,
        dropout_rate=self.dropout,
        broadcast_dropout=True,
        deterministic=deterministic,
        dtype=self.dtype,
        precision=None,
    )

    context = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
    context = self.out_proj(self._merge_heads(context))

    return context, attn_weights
def __call__(
    self,
    hidden_states,
    key_value_states: Optional[jnp.ndarray] = None,
    attention_mask=None,
    deterministic: bool = True,
    init_cache: bool = False,
    output_attentions: bool = False,
):
    """Multi-head attention using a fused `c_attn` projection.

    Args:
        hidden_states: Input array; always the source of the queries.
        key_value_states: When provided, the layer acts as a cross-attention
            layer for the decoder: queries come from `q_attn(hidden_states)`
            and keys/values from `c_attn(key_value_states)`. Otherwise
            `c_attn(hidden_states)` is split into query, key and value.
        attention_mask: Optional mask; entries `> 0` may be attended to.
        deterministic: When false and `config.attn_pdrop` is positive, a
            "dropout" RNG is requested for the attention weights. Also
            forwarded to the residual dropout.
        init_cache: When true (and the layer is causal), the key/value cache
            path is taken even if no "cached_key" variable exists yet.
        output_attentions: When true, the attention weights are returned too.

    Returns:
        `(attn_output,)`, or `(attn_output, attn_weights)` when
        `output_attentions` is true.
    """
    batch_size = hidden_states.shape[0]

    if key_value_states is None:
        # Self-attention: one fused projection split three ways on axis 2.
        query, key, value = jnp.split(self.c_attn(hidden_states), 3, axis=2)
    else:
        # Cross-attention: separate query projection, fused key/value one.
        query = jnp.split(self.q_attn(hidden_states), 1, axis=2)[0]
        key, value = jnp.split(self.c_attn(key_value_states), 2, axis=2)

    query = self._split_heads(query)
    key = self._split_heads(key)
    value = self._split_heads(value)

    wants_dropout = (not deterministic) and self.config.attn_pdrop > 0.0
    dropout_rng = self.make_rng("dropout") if wants_dropout else None

    if self.causal:
        has_cache = self.has_variable("cache", "cached_key")
        q_len, k_len = query.shape[1], key.shape[1]

        if has_cache:
            # Slice the causal-mask rows starting at the current cache index,
            # spanning the full length of the cached keys.
            cache_vars = self.variables["cache"]
            start_row = cache_vars["cache_index"]
            cached_len = cache_vars["cached_key"].shape[1]
            causal_mask = lax.dynamic_slice(
                self.causal_mask, (0, 0, start_row, 0), (1, 1, q_len, cached_len)
            )
        else:
            causal_mask = self.causal_mask[:, :, :q_len, :k_len]
        causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])

        # Merge the caller's mask (if any) with the causal mask.
        if attention_mask is None:
            attention_mask = causal_mask
        else:
            expanded = jnp.expand_dims(attention_mask, axis=(-3, -2))
            expanded = jnp.broadcast_to(expanded, causal_mask.shape)
            attention_mask = combine_masks(expanded, causal_mask)

        # During fast autoregressive decoding, one position is fed at a time
        # and keys/values are cached step by step.
        if has_cache or init_cache:
            key, value, attention_mask = self._concatenate_to_cache(
                key, value, query, attention_mask
            )
    elif attention_mask is not None:
        attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

    # Transform the boolean mask into a float bias: 0 where allowed, -1e4
    # elsewhere.
    attention_bias = None
    if attention_mask is not None:
        allowed = jnp.full(attention_mask.shape, 0.0).astype(self.dtype)
        blocked = jnp.full(attention_mask.shape, -1e4).astype(self.dtype)
        attention_bias = lax.select(attention_mask > 0, allowed, blocked)

    # usual dot product attention
    attn_weights = dot_product_attention_weights(
        query,
        key,
        bias=attention_bias,
        dropout_rng=dropout_rng,
        dropout_rate=self.config.attn_pdrop,
        deterministic=deterministic,
        dtype=self.dtype,
        precision=None,
    )

    context = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
    context = self.c_proj(self._merge_heads(context))
    context = self.resid_dropout(context, deterministic=deterministic)

    if output_attentions:
        return (context, attn_weights)
    return (context,)