def _funnel_mask(self, batch_size, keys_len, queries_len, funnel_factor,
                 is_upsampling):
    """Creates a funnel mask.

    Based on the keys/queries lengths, this function creates a triangular
    mask that prevents tokens from attending to positions that follow them.
    If funnel_factor is not equal to 1, due to funnel upsampling or
    downsampling, the mask is adjusted for funnel attention by repeating
    each element funnel_factor times. This is because after a funnel layer
    one token attends to funnel_factor different tokens during downsampling;
    during upsampling, on the other hand, funnel_factor tokens attend to a
    single pre-upsampling token.

    Args:
      batch_size: batch size.
      keys_len: keys length.
      queries_len: queries length.
      funnel_factor: funnel factor.
      is_upsampling: True if the mask is for upsampling, False otherwise.

    Returns:
      Funnel mask.
    """
    if self._mode == 'predict':
        # We generate one token at a time; generating more would violate
        # the autoregressive property.
        assert queries_len == 1
        mask = jnp.arange(
            self._max_len) <= (self.state // self._total_kv_pooling)
        mask = jnp.reshape(mask, (1, 1, 1, self._max_len))
        mask = jnp.repeat(mask, batch_size, axis=0)
        self.state += self._n_raw_tokens_generated
        return mask

    if funnel_factor != 1:
        if not is_upsampling:
            mask = jnp.tril(
                jnp.ones((queries_len, queries_len), dtype=jnp.bool_))
            mask = jnp.repeat(mask, funnel_factor, axis=-1)
        else:
            mask = jnp.tril(jnp.ones((keys_len, keys_len), dtype=jnp.bool_))
            mask = jnp.repeat(mask, funnel_factor, axis=-2)
    else:
        mask = jnp.tril(
            jnp.ones((queries_len, queries_len), dtype=jnp.bool_))

    return jnp.repeat(mask[None, None, :, :], batch_size, axis=0)
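# A minimal standalone sketch (not library code) of the predict-mode branch
# above: `state` counts raw tokens generated so far, and `max_len` and
# `total_kv_pooling` are illustrative values standing in for the layer's
# attributes.
import jax.numpy as jnp

max_len = 8            # assumed maximum sequence length
total_kv_pooling = 2   # assumed overall pooling factor

for state in range(4):
    mask = jnp.arange(max_len) <= (state // total_kv_pooling)
    print(state, mask.astype(jnp.int32))
# Raw tokens 0 and 1 see only pooled key 0; tokens 2 and 3 see keys 0-1, etc.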
def forward(self, inputs):
    """Returns attention-computed activations.

    Args:
      inputs: A (queries, keys, values) tuple.
    """
    q, k, v = inputs

    if self._mode == 'predict':
        self.state = _fast_inference_update_state(inputs, self.state)
        (k, v, mask, _) = self.state
    else:
        mask_size = q.shape[-2]
        # Not all backends define jnp.tril. However, using np.tril is
        # inefficient in that it creates a large global constant.
        # TODO(kitaev): try to find an alternative that works across all
        # backends.
        if fastmath.is_backend(fastmath.Backend.JAX):
            mask = jnp.tril(
                jnp.ones((1, mask_size, mask_size), dtype=np.bool_), k=0)
        else:
            mask = np.tril(
                np.ones((1, mask_size, mask_size), dtype=np.bool_), k=0)

    res, dots = DotProductAttention(
        q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=self.rng)
    if self._mode == 'viz':
        self.state = dots
    return res
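# `DotProductAttention` is not defined in this section. Judging from the call
# site it returns both the activations and the attention weights (`dots`).
# Below is a minimal sketch under that assumption; the name, signature, and
# dropout handling are illustrative, not the library's actual implementation.
import jax
import jax.numpy as jnp

def dot_product_attention_sketch(q, k, v, mask, dropout=0.0, mode='train',
                                 rng=None):
    """Masked, scaled dot-product attention; returns (activations, dots)."""
    depth = q.shape[-1]
    dots = jnp.matmul(q, jnp.swapaxes(k, -1, -2)) / jnp.sqrt(depth)
    # Large negative logits at masked positions vanish under softmax.
    dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
    dots = jax.nn.softmax(dots, axis=-1)
    if dropout > 0.0 and mode == 'train':
        # Inverted dropout on the attention weights.
        keep = jax.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
    return jnp.matmul(dots, v), dots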
def forward(self, inputs):
    """Returns attention-computed activations from a (q, k, v) tuple."""
    q, k, v = inputs

    if self._mode == 'predict':
        self.state = _fast_inference_update_state(inputs, self.state)
        (k, v, mask, _) = self.state
    else:
        mask_size = q.shape[-2]
        # Not all backends define jnp.tril. However, using np.tril is
        # inefficient in that it creates a large global constant.
        # TODO(kitaev): try to find an alternative that works across all
        # backends.
        if fastmath.backend_name() == 'jax':
            mask = jnp.tril(
                jnp.ones((1, mask_size, mask_size), dtype=np.bool_), k=0)
        else:
            mask = np.tril(
                np.ones((1, mask_size, mask_size), dtype=np.bool_), k=0)

    res, dots = DotProductAttention(
        q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=self.rng)
    if self._mode == 'viz':
        self.state = dots
    return res
def _causal_mask(length):
    """Returns a (1, length, length) boolean lower-triangular causal mask."""
    # Not all backends define jnp.tril. However, using np.tril is inefficient
    # in that it creates a large global constant. TODO(kitaev): try to find
    # an alternative that works across all backends.
    if fastmath.is_backend(fastmath.Backend.JAX):
        return jnp.tril(jnp.ones((1, length, length), dtype=np.bool_), k=0)
    else:
        return np.tril(np.ones((1, length, length), dtype=np.bool_), k=0)
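# For example, _causal_mask(4) yields a lower-triangular boolean mask with a
# leading axis of size 1 so it broadcasts across batch (and heads):
import numpy as np

mask = np.tril(np.ones((1, 4, 4), dtype=np.bool_), k=0)
print(mask.shape)           # (1, 4, 4)
print(mask[0].astype(int))
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 1]]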
def _funnel_mask(batch_size, keys_len, queries_len, funnel_factor,
                 is_upsampling):
    """Creates a funnel mask.

    Based on the keys/queries lengths, this function creates a triangular
    mask that prevents tokens from attending to positions that follow them.
    If funnel_factor is not equal to 1, due to funnel upsampling or
    downsampling, the mask is adjusted for funnel attention by repeating
    each element funnel_factor times. This is because after a funnel layer
    one token attends to funnel_factor different tokens during downsampling;
    during upsampling, on the other hand, funnel_factor tokens attend to a
    single pre-upsampling token.

    Args:
      batch_size: batch size.
      keys_len: keys length.
      queries_len: queries length.
      funnel_factor: funnel factor.
      is_upsampling: True if the mask is for upsampling, False otherwise.

    Returns:
      Funnel mask.
    """
    if funnel_factor != 1:
        if not is_upsampling:
            mask = jnp.tril(
                jnp.ones((queries_len, queries_len), dtype=jnp.bool_))
            mask = jnp.repeat(mask, funnel_factor, axis=-1)
        else:
            mask = jnp.tril(jnp.ones((keys_len, keys_len), dtype=jnp.bool_))
            mask = jnp.repeat(mask, funnel_factor, axis=-2)
    else:
        mask = jnp.tril(
            jnp.ones((queries_len, queries_len), dtype=jnp.bool_))

    return jnp.repeat(mask[None, None, :, :], batch_size, axis=0)
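# A small sketch of the two funnel cases, using plain jnp with illustrative
# sizes (funnel_factor=2, pooled length 3); not part of the library code.
import jax.numpy as jnp

tri = jnp.tril(jnp.ones((3, 3), dtype=jnp.bool_))

# Downsampling: 3 pooled queries attend over 6 raw keys. Repeating each
# column twice lets pooled query i see raw key positions 0..(2*i + 1).
down = jnp.repeat(tri, 2, axis=-1)   # shape (3, 6)

# Upsampling: 6 upsampled queries attend over 3 pooled keys. Repeating each
# row twice makes upsampled queries 2*i and 2*i + 1 share pooled row i.
up = jnp.repeat(tri, 2, axis=-2)     # shape (6, 3)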
def dot_product_self_attention(q, k, v):
    """Masked dot product self attention.

    Args:
      q (jax.interpreters.xla.DeviceArray): queries.
      k (jax.interpreters.xla.DeviceArray): keys.
      v (jax.interpreters.xla.DeviceArray): values.

    Returns:
      jax.interpreters.xla.DeviceArray: masked dot product self attention
      tensor.
    """
    mask_size = q.shape[-2]
    # A matrix with True on and below the diagonal and False above it, of
    # shape (1, mask_size, mask_size).
    mask = jnp.tril(jnp.ones((1, mask_size, mask_size), dtype=jnp.bool_),
                    k=0)
    return DotProductAttention(q, k, v, mask)
def forward(self, inputs):
    inputs_len = inputs.shape[1]

    if self._mode == 'predict':
        # We generate one token at a time; generating more would violate
        # the autoregressive property.
        assert inputs_len == 1
        current_token, sequence_length = calc_predict_next_token_index(
            self.state, self._total_kv_pooling, self._max_len,
            self._chunk_len, self._chunk_offset)
        mask = jnp.arange(sequence_length) <= current_token
        mask = jnp.reshape(mask, (1, sequence_length))
        self.state += self._n_raw_tokens_generated
        return mask

    if self._chunk_len is not None:
        return jnp.tril(
            jnp.ones((self._chunk_len, self._chunk_len), dtype=jnp.bool_))

    return jnp.tril(jnp.ones((inputs_len, inputs_len), dtype=jnp.bool_))
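# With `_chunk_len` set, attention appears to be computed independently per
# chunk, so one (chunk_len, chunk_len) triangle serves every chunk. A minimal
# illustration with an assumed chunk length of 3:
import jax.numpy as jnp

chunk_len = 3
chunk_mask = jnp.tril(jnp.ones((chunk_len, chunk_len), dtype=jnp.bool_))
# Broadcast against inputs reshaped to (..., n_chunks, chunk_len, d_feature),
# the same triangular mask applies within each chunk.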
def dot_product_self_attention(q, k, v):
    """Masked dot product self attention.

    Args:
      q (jax.interpreters.xla.DeviceArray): queries.
      k (jax.interpreters.xla.DeviceArray): keys.
      v (jax.interpreters.xla.DeviceArray): values.

    Returns:
      jax.interpreters.xla.DeviceArray: masked dot product self attention
      tensor.
    """
    # For causal attention the logits are (Q . K^T) + M, where M masks out
    # future positions; the mask covers the last two axes, i.e. L_q x L_k.
    mask_size = q.shape[-2]
    # A matrix with True on and below the diagonal and False above it, of
    # shape (1, mask_size, mask_size). The 1s and 0s from jnp.ones() are
    # cast to True/False by dtype=jnp.bool_; jnp.tril() keeps the lower
    # triangle.
    mask = jnp.tril(jnp.ones((1, mask_size, mask_size), dtype=jnp.bool_),
                    k=0)
    return DotProductAttention(q, k, v, mask)
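# The comment above mentions the additive form (Q . K^T) + M. A sketch of how
# the boolean mask maps to that additive M (0 where attention is allowed, a
# large negative number where it is not); illustrative, not library code.
import jax.numpy as jnp

L = 4  # illustrative sequence length
bool_mask = jnp.tril(jnp.ones((1, L, L), dtype=jnp.bool_), k=0)
additive_m = jnp.where(bool_mask, 0.0, -1e9)
# softmax((Q . K^T) + M) then assigns ~0 probability to future positions.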
def dot_product_self_attention(q, k, v):
    """Masked dot product self attention (compact variant of the above)."""
    mask_size = q.shape[-2]
    mask = jnp.tril(jnp.ones((1, mask_size, mask_size), dtype=jnp.bool_),
                    k=0)
    return DotProductAttention(q, k, v, mask)