def __call__(self, x):
  B, H, W, C = x.shape  # pylint: disable=invalid-name,unused-variable
  assert C % self.num_heads == 0
  head_dim = C // self.num_heads

  h = Normalize(name='norm')(x)
  assert h.shape == (B, H, W, C)
  h = h.reshape(B, H * W, C)
  q = nn.DenseGeneral(features=(self.num_heads, head_dim), name='q')(h)
  k = nn.DenseGeneral(features=(self.num_heads, head_dim), name='k')(h)
  v = nn.DenseGeneral(features=(self.num_heads, head_dim), name='v')(h)
  assert q.shape == k.shape == v.shape == (B, H * W, self.num_heads, head_dim)
  h = nn.dot_product_attention(query=q, key=k, value=v)
  assert h.shape == (B, H * W, self.num_heads, head_dim)
  h = nn.DenseGeneral(
      features=C,
      axis=(-2, -1),
      kernel_init=nn.initializers.zeros,
      name='proj_out')(h)
  assert h.shape == (B, H * W, C)
  h = h.reshape(B, H, W, C)
  assert h.shape == x.shape
  return x + h
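# A standalone sketch (not part of the block above; sizes are made up) checking the
# DenseGeneral / dot_product_attention shape contract the block relies on:
# features=(num_heads, head_dim) splits the channel axis into heads, and
# axis=(-2, -1) on the output projection merges the heads back into C channels.
import jax
import jax.numpy as jnp
from flax import linen as nn

rng = jax.random.PRNGKey(0)
B, L, C, num_heads = 2, 16, 64, 4
head_dim = C // num_heads
h = jnp.ones((B, L, C))

to_heads = nn.DenseGeneral(features=(num_heads, head_dim))
q = to_heads.apply(to_heads.init(rng, h), h)
assert q.shape == (B, L, num_heads, head_dim)

attn = nn.dot_product_attention(query=q, key=q, value=q)
assert attn.shape == (B, L, num_heads, head_dim)

merge = nn.DenseGeneral(features=C, axis=(-2, -1))
out = merge.apply(merge.init(rng, attn), attn)
assert out.shape == (B, L, C)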
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
    head_dim = self.config.hidden_size // self.config.num_attention_heads

    query_states = self.query(hidden_states).reshape(
        hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
    )
    value_states = self.value(hidden_states).reshape(
        hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
    )
    key_states = self.key(hidden_states).reshape(
        hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
    )

    # Convert the boolean attention mask to an attention bias.
    if attention_mask is not None:
        # attention mask in the form of attention bias
        attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
        )
    else:
        attention_bias = None

    dropout_rng = None
    if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
        dropout_rng = self.make_rng("dropout")

    attn_output = dot_product_attention(
        query_states,
        key_states,
        value_states,
        bias=attention_bias,
        dropout_rng=dropout_rng,
        dropout_rate=self.config.attention_probs_dropout_prob,
        broadcast_dropout=True,
        deterministic=deterministic,
        dtype=self.dtype,
        precision=None,
    )

    outputs = (attn_output.reshape(attn_output.shape[:2] + (-1,)),)

    # TODO: at the moment it's not possible to retrieve attn_weights from
    # dot_product_attention, but should be in the future -> add functionality then
    return outputs
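# A standalone sketch (toy mask values; not tied to the module above) of the
# boolean-mask-to-additive-bias conversion used in the branch above: positions
# with mask > 0 get bias 0.0, padded positions get a large negative value that
# effectively zeroes their attention weight after the softmax.
import jax.numpy as jnp
from jax import lax

attention_mask = jnp.array([[1, 1, 1, 0], [1, 1, 0, 0]])       # (batch, seq_len)
mask = jnp.expand_dims(attention_mask, axis=(-3, -2))          # (batch, 1, 1, seq_len)
attention_bias = lax.select(
    mask > 0,
    jnp.full(mask.shape, 0.0),
    jnp.full(mask.shape, -1e10),
)
# attention_bias broadcasts against attention logits of shape
# (batch, num_heads, q_len, kv_len) when passed as `bias=` to dot_product_attention.
assert attention_bias.shape == (2, 1, 1, 4)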
def __call__(self, language, vision, hidden):
    # Queries attend from the concatenated language + hidden features;
    # keys come from the language features and values from the vision features.
    input_q = jnp.concatenate([language, hidden], axis=-1)
    query = nn.DenseGeneral(features=(self.num_heads, self.head_dim), name="query")(input_q)
    key = nn.DenseGeneral(features=(self.num_heads, self.head_dim), name="key")(language)
    value = nn.DenseGeneral(features=(self.num_heads, self.head_dim), name="memory_value")(vision)
    x = nn.dot_product_attention(query, key, value)
    # Merge the (num_heads, head_dim) axes back into a single feature axis.
    out = nn.DenseGeneral(features=self.out_features, axis=(-2, -1), name="out")(x)
    return out
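# A minimal sketch (the class name CrossAttention, the attribute defaults, and the
# dummy shapes are assumptions, not from the original module) showing how the
# __call__ above could be wrapped in a Flax module and applied; key and value
# inputs must share the sequence-length axis for dot_product_attention.
import jax
import jax.numpy as jnp
from flax import linen as nn

class CrossAttention(nn.Module):
    num_heads: int = 4
    head_dim: int = 32
    out_features: int = 128

    @nn.compact
    def __call__(self, language, vision, hidden):
        input_q = jnp.concatenate([language, hidden], axis=-1)
        query = nn.DenseGeneral(features=(self.num_heads, self.head_dim), name="query")(input_q)
        key = nn.DenseGeneral(features=(self.num_heads, self.head_dim), name="key")(language)
        value = nn.DenseGeneral(features=(self.num_heads, self.head_dim), name="memory_value")(vision)
        x = nn.dot_product_attention(query, key, value)
        return nn.DenseGeneral(features=self.out_features, axis=(-2, -1), name="out")(x)

rng = jax.random.PRNGKey(0)
language = jnp.ones((2, 10, 128))
vision = jnp.ones((2, 10, 256))
hidden = jnp.ones((2, 10, 64))
model = CrossAttention()
params = model.init(rng, language, vision, hidden)
out = model.apply(params, language, vision, hidden)
assert out.shape == (2, 10, 128)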
def __call__(self, hidden_states, attention_mask, deterministic=True):
    head_dim = self.config.hidden_size // self.config.num_attention_heads

    query_states = self.query(hidden_states).reshape(
        hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
    )
    value_states = self.value(hidden_states).reshape(
        hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
    )
    key_states = self.key(hidden_states).reshape(
        hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
    )

    # Convert the boolean attention mask to an attention bias.
    if attention_mask is not None:
        # attention mask in the form of attention bias
        attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
        )
    else:
        attention_bias = None

    # Only draw a dropout rng when the same rate that is passed to
    # dot_product_attention below is actually non-zero.
    dropout_rng = None
    if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
        dropout_rng = self.make_rng("dropout")

    attn_output = dot_product_attention(
        query_states,
        key_states,
        value_states,
        bias=attention_bias,
        dropout_rng=dropout_rng,
        dropout_rate=self.config.attention_probs_dropout_prob,
        broadcast_dropout=True,
        deterministic=deterministic,
        dtype=self.dtype,
        precision=None,
    )

    return attn_output.reshape(attn_output.shape[:2] + (-1,))
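# A standalone sketch (tiny stand-in module, not the attention layer above) of how
# the "dropout" rng stream consumed via self.make_rng("dropout") is supplied at
# apply time; the same rngs={"dropout": ...} argument drives the attention dropout above.
import jax
import jax.numpy as jnp
from flax import linen as nn

class DropBlock(nn.Module):
    @nn.compact
    def __call__(self, x, deterministic):
        return nn.Dropout(rate=0.1)(x, deterministic=deterministic)

m = DropBlock()
x = jnp.ones((2, 4))
variables = m.init(jax.random.PRNGKey(0), x, deterministic=True)
# Non-deterministic apply needs an explicit "dropout" rng stream.
y = m.apply(variables, x, deterministic=False, rngs={"dropout": jax.random.PRNGKey(1)})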
def __call__(self, x):
  B, H, W, C = x.shape  # pylint: disable=invalid-name,unused-variable

  if self.head_dim is None:
    assert self.num_heads is not None
    assert C % self.num_heads == 0
    num_heads = self.num_heads
    head_dim = C // num_heads
  else:
    assert self.num_heads is None
    assert C % self.head_dim == 0
    head_dim = self.head_dim
    num_heads = C // head_dim

  h = Normalize(name='norm')(x)
  assert h.shape == (B, H, W, C)
  h = h.reshape(B, H * W, C)
  q = nn.DenseGeneral(features=(num_heads, head_dim), name='q')(h)
  k = nn.DenseGeneral(features=(num_heads, head_dim), name='k')(h)
  v = nn.DenseGeneral(features=(num_heads, head_dim), name='v')(h)
  assert q.shape == k.shape == v.shape == (B, H * W, num_heads, head_dim)
  h = nn.dot_product_attention(query=q, key=k, value=v)
  assert h.shape == (B, H * W, num_heads, head_dim)
  h = nn.DenseGeneral(
      features=C,
      axis=(-2, -1),
      kernel_init=nn.initializers.zeros,
      name='proj_out')(h)
  assert h.shape == (B, H * W, C)
  h = h.reshape(B, H, W, C)
  assert h.shape == x.shape

  logging.info(
      '%s: x=%r num_heads=%d head_dim=%d',
      self.name, x.shape, num_heads, head_dim)
  return x + h
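# A small sketch (pure-Python helper, not part of the original module) of the
# num_heads / head_dim resolution rule used above: exactly one of the two is
# configured, and the other is derived from the channel count C.
def resolve_heads(C, num_heads=None, head_dim=None):
  if head_dim is None:
    assert num_heads is not None and C % num_heads == 0
    return num_heads, C // num_heads
  assert num_heads is None and C % head_dim == 0
  return C // head_dim, head_dim

assert resolve_heads(256, num_heads=8) == (8, 32)
assert resolve_heads(256, head_dim=64) == (4, 64)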
def __call__(
    self,
    hidden_states,
    attention_mask=None,
    deterministic: bool = True,
    init_cache: bool = False,
    output_attentions: bool = False,
):
    qkv_out = self.c_attn(hidden_states)
    query, key, value = jnp.split(qkv_out, 3, axis=2)

    query = self._split_heads(query)
    key = self._split_heads(key)
    value = self._split_heads(value)

    query_length, key_length = query.shape[1], key.shape[1]

    if self.has_variable("cache", "cached_key"):
        mask_shift = self.variables["cache"]["cache_index"]
        max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
        causal_mask = lax.dynamic_slice(
            self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
        )
    else:
        causal_mask = self.causal_mask[:, :, :query_length, :key_length]

    batch_size = hidden_states.shape[0]
    causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])

    attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
    attention_mask = combine_masks(attention_mask, causal_mask)

    dropout_rng = None
    if not deterministic and self.config.attn_pdrop > 0.0:
        dropout_rng = self.make_rng("dropout")

    # During fast autoregressive decoding, we feed one position at a time,
    # and cache the keys and values step by step.
    if self.has_variable("cache", "cached_key") or init_cache:
        key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)

    # transform boolean mask into float mask
    attention_bias = lax.select(
        attention_mask > 0,
        jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
        jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
    )

    # usual dot product attention
    attn_output = dot_product_attention(
        query,
        key,
        value,
        bias=attention_bias,
        dropout_rng=dropout_rng,
        dropout_rate=self.config.attn_pdrop,
        deterministic=deterministic,
        dtype=self.dtype,
        precision=None,
    )

    attn_output = self._merge_heads(attn_output)
    attn_output = self.c_proj(attn_output)
    attn_output = self.resid_dropout(attn_output, deterministic=deterministic)

    # TODO: at the moment it's not possible to retrieve attn_weights from
    # dot_product_attention, but should be in the future -> add functionality then
    return (attn_output,)
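# A standalone sketch (toy sizes; independent of the module above) of how the
# causal mask and the padding mask are combined: make_causal_mask builds the
# lower-triangular (1, 1, len, len) mask that self.causal_mask above is assumed
# to hold, and combine_masks takes the logical AND of the two masks.
import jax.numpy as jnp
from flax.linen import combine_masks, make_causal_mask

seq_len, batch_size = 4, 2
causal_mask = make_causal_mask(jnp.ones((1, seq_len), dtype="bool"), dtype="bool")
attention_mask = jnp.array([[1, 1, 1, 0], [1, 1, 0, 0]])        # (batch, seq_len) padding mask
attention_mask = jnp.broadcast_to(
    jnp.expand_dims(attention_mask, axis=(-3, -2)),
    (batch_size,) + causal_mask.shape[1:],
)
mask = combine_masks(attention_mask, causal_mask)               # logical AND of both masks
assert mask.shape == (batch_size, 1, seq_len, seq_len)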