def apply(self, inputs_q, inputs_kv, num_heads, dtype=jnp.float32,
          qkv_features=None, out_features=None, attention_axis=None,
          causal_mask=False, padding_mask=None, key_padding_mask=None,
          segmentation=None, key_segmentation=None, cache=None,
          broadcast_dropout=True, dropout_rng=None, dropout_rate=0.,
          deterministic=False, precision=None,
          kernel_init=nn.linear.default_kernel_init,
          bias_init=nn.initializers.zeros, bias=True, num_partitions=2):
  """Applies multi-head dot product attention on the input data.

  Projects the inputs into multi-headed query, key, and value vectors,
  applies dot-product attention and projects the results to an output vector.

  This can be used for encoder-decoder attention by specifying both
  `inputs_q` and `inputs_kv`, or for self-attention by only specifying
  `inputs_q` and setting `inputs_kv` to None.

  Args:
    inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
    inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]`
      or None for self-attention, in which case key/values will be derived
      from inputs_q.
    num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
      should be divisible by the number of heads.
    dtype: the dtype of the computation (default: float32).
    qkv_features: dimension of the key, query, and value.
    out_features: dimension of the last projection.
    attention_axis: axes over which the attention is applied (`None` means
      attention over all axes, but batch, heads, and features).
    causal_mask: boolean specifying whether to apply a causal mask on the
      attention weights. If True, the output at timestep `t` will not depend
      on inputs at timesteps strictly greater than `t`.
    padding_mask: boolean specifying query tokens that are pad token.
    key_padding_mask: boolean specifying key-value tokens that are pad token.
    segmentation: segment indices for packed inputs_q data.
    key_segmentation: segment indices for packed inputs_kv data.
    cache: an instance of `flax.nn.attention.Cache` used for efficient
      autoregressive decoding.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey: to be used for dropout.
    dropout_rate: dropout rate.
    deterministic: bool, deterministic or not (to apply dropout).
    precision: numerical precision of the computation, see `jax.lax.Precision`
      for details.
    kernel_init: initializer for the kernel of the Dense layers.
    bias_init: initializer for the bias of the Dense layers.
    bias: bool: whether pointwise QKVO dense transforms use bias.
    num_partitions: number of ways to partition (i.e. how many devices to run
      across).

  Returns:
    output of shape `[bs, dim1, dim2, ..., dimN, features]`.
  """
  assert causal_mask or not cache, (
      'Caching is only supported for causal attention.')

  if inputs_kv is None:
    inputs_kv = inputs_q

  if attention_axis is None:
    attention_axis = tuple(range(1, inputs_q.ndim - 1))

  features = out_features or inputs_q.shape[-1]
  qkv_features = qkv_features or inputs_q.shape[-1]

  assert qkv_features % num_heads == 0, (
      'Memory dimension must be divisible by number of heads.')
  head_dim = qkv_features // num_heads

  dense = nn.DenseGeneral.partial(
      axis=-1, features=(num_heads, head_dim), kernel_init=kernel_init,
      bias_init=bias_init, bias=bias, precision=precision)
  # Project inputs_q to multi-headed q/k/v.
  # Dimensions are then [bs, dims..., n_heads, n_features_per_head].
  query, key, value = (dense(inputs_q, dtype=dtype, name='query'),
                       dense(inputs_kv, dtype=dtype, name='key'),
                       dense(inputs_kv, dtype=dtype, name='value'))

  if num_partitions > 1:
    partitions = P(1, 1, num_partitions, 1)
    query = with_sharding_constraint(query, partitions)
    key = with_sharding_constraint(key, partitions)
    value = with_sharding_constraint(value, partitions)

  if cache:
    assert isinstance(cache, Cache), 'cache must be an instance of Cache'
    if self.is_initializing():
      cache.store(lambda: (key.ndim, key.shape[-2:]))
    else:
      cache_entry = cache.retrieve(None)
      expected_shape = list(cache_entry.key.shape[:-2])
      for attn_dim in attention_axis:
        expected_shape[attn_dim] = 1
      expected_shape = tuple(expected_shape) + inputs_q.shape[-1:]
      if expected_shape != inputs_q.shape:
        raise ValueError('Invalid shape provided, expected shape %s '
                         'instead got %s.' % (expected_shape, inputs_q.shape))

      if not isinstance(cache_entry, _CacheEntry):
        raise ValueError('Cache is not initialized.')

      cshape = cache_entry.key.shape
      i = cache_entry.i
      # Write the new key/value at position i via a one-hot update.
      one_hot_indices = jax.nn.one_hot(i, cshape[3], dtype=key.dtype).reshape(
          (1, 1, 1, cshape[3]))
      key = key.transpose((0, 2, 3, 1))
      key = cache_entry.key + key * one_hot_indices
      value = value.transpose((0, 2, 3, 1))
      value = cache_entry.value + value * one_hot_indices

      one = jnp.array(1, jnp.uint32)
      cache_entry = cache_entry.replace(
          i=cache_entry.i + one, key=key, value=value)
      cache.store(cache_entry)

      key = key.transpose((0, 3, 1, 2))
      value = value.transpose((0, 3, 1, 2))
      cshape = (cshape[0], cshape[3], cshape[1], cshape[2])

      # TODO(levskaya): verify this is still needed in translation decoding.
      key_padding_mask = jnp.broadcast_to(
          (jnp.arange(cshape[1]) < cache_entry.i), cshape[:2])
      key_padding_mask = key_padding_mask.astype(jnp.float32)[..., None]

  # Create attention masks.
  mask_components = []

  if causal_mask:
    if cache and not self.is_initializing():
      bias_pre_shape = (1,) * (key.ndim - 1)
      attn_shape = tuple(np.take(key.shape, attention_axis))
      attn_size = np.prod(attn_shape)
      ii = jnp.arange(attn_size, dtype=jnp.uint32)
      mask = ii < cache_entry.i
      mask_components.append(mask.reshape(bias_pre_shape + attn_shape))
    else:
      mask_components.append(_make_causal_mask(key, attention_axis))

  if padding_mask is not None:
    if key_padding_mask is None:
      key_padding_mask = padding_mask
    padding_mask = make_padding_mask(
        padding_mask_query=padding_mask, padding_mask_key=key_padding_mask,
        query_shape=query.shape, key_shape=key.shape,
        attention_axis=attention_axis)
    mask_components.append(padding_mask)

  if segmentation is not None:
    if key_segmentation is None:
      key_segmentation = segmentation
    segmentation_mask = make_padding_mask(
        padding_mask_query=segmentation, padding_mask_key=key_segmentation,
        query_shape=query.shape, key_shape=key.shape,
        attention_axis=attention_axis, segmentation_mask=True)
    mask_components.append(segmentation_mask)

  if mask_components:
    attention_mask = mask_components[0]
    for component in mask_components[1:]:
      attention_mask = jnp.logical_and(attention_mask, component)

    # Attention mask in the form of attention bias.
    attention_bias = lax.select(
        attention_mask > 0,
        jnp.full(attention_mask.shape, 0.).astype(dtype),
        jnp.full(attention_mask.shape, -1e10).astype(dtype))
  else:
    attention_bias = None

  # Apply attention.
  x = dot_product_attention(
      query, key, value, dtype=dtype, axis=attention_axis,
      bias=attention_bias, precision=precision, dropout_rng=dropout_rng,
      dropout_rate=dropout_rate, broadcast_dropout=broadcast_dropout,
      deterministic=deterministic)

  if num_partitions > 1:
    # Constrain the sharding of the attention output before the out projection.
    x = with_sharding_constraint(x, None)

  # Back to the original inputs dimensions.
  out = nn.DenseGeneral(
      x, features=features, axis=(-2, -1), kernel_init=kernel_init,
      bias_init=bias_init, bias=bias, dtype=dtype, precision=precision,
      name='out')

  return out
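
# Minimal standalone sketch (illustrative only, not part of the module above;
# it reuses this file's `jax`/`jnp` imports) of the one-hot cache update used
# for autoregressive decoding: the key/value column for step `i` is written
# into the cache with a one-hot product instead of a dynamic-slice update.
def _one_hot_cache_update_sketch(cached, new_col, i):
  """cached: [features, cache_len]; new_col: [features, 1]; i: int step."""
  cache_len = cached.shape[-1]
  one_hot = jax.nn.one_hot(i, cache_len, dtype=cached.dtype).reshape(
      (1, cache_len))
  # Column i is incremented by new_col; since the cache slot is zero before
  # being written, it ends up holding new_col while other columns are kept.
  return cached + new_col * one_hot
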
def apply(self, inputs_q, inputs_kv, num_heads, dtype=jnp.float32,
          qkv_features=None, out_features=None, attention_axis=None,
          causal_mask=False, padding_mask=None, key_padding_mask=None,
          segmentation=None, key_segmentation=None, cache=None,
          broadcast_dropout=True, dropout_rng=None, dropout_rate=0.,
          deterministic=False, precision=None,
          kernel_init=nn.linear.default_kernel_init,
          bias_init=nn.initializers.zeros, bias=True, block_size=50,
          max_num_blocks=25, sort_activation='softmax'):
  """Applies multi-head Sinkhorn attention on the input data.

  Projects the inputs into multi-headed query, key, and value vectors,
  applies dot-product attention over sorted blocks and projects the results
  to an output vector.

  This can be used for encoder-decoder attention by specifying both
  `inputs_q` and `inputs_kv`, or for self-attention by only specifying
  `inputs_q` and setting `inputs_kv` to None.

  Args:
    inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
    inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]`
      or None for self-attention, in which case key/values will be derived
      from inputs_q.
    num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
      should be divisible by the number of heads.
    dtype: the dtype of the computation (default: float32).
    qkv_features: dimension of the key, query, and value.
    out_features: dimension of the last projection.
    attention_axis: axes over which the attention is applied (`None` means
      attention over all axes, but batch, heads, and features).
    causal_mask: boolean specifying whether to apply a causal mask on the
      attention weights. If True, the output at timestep `t` will not depend
      on inputs at timesteps strictly greater than `t`.
    padding_mask: boolean specifying query tokens that are pad token.
    key_padding_mask: boolean specifying key-value tokens that are pad token.
    segmentation: segment indices for packed inputs_q data.
    key_segmentation: segment indices for packed inputs_kv data.
    cache: an instance of `flax.nn.attention.Cache` used for efficient
      autoregressive decoding.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey: to be used for dropout.
    dropout_rate: dropout rate.
    deterministic: bool, deterministic or not (to apply dropout).
    precision: numerical precision of the computation, see `jax.lax.Precision`
      for details.
    kernel_init: initializer for the kernel of the Dense layers.
    bias_init: initializer for the bias of the Dense layers.
    bias: bool: whether pointwise QKVO dense transforms use bias.
    block_size: int, size of each attention block.
    max_num_blocks: int, maximum number of blocks.
    sort_activation: str, one of {'softmax', 'sinkhorn', 'gumbel_sinkhorn'};
      activation used to produce the block permutation.

  Returns:
    output of shape `[bs, dim1, dim2, ..., dimN, features]`.
  """
  assert causal_mask or not cache, (
      'Caching is only supported for causal attention.')
  assert inputs_q.ndim == 3

  if inputs_kv is None:
    inputs_kv = inputs_q

  if attention_axis is None:
    attention_axis = tuple(range(1, inputs_q.ndim - 1))

  features = out_features or inputs_q.shape[-1]
  qkv_features = qkv_features or inputs_q.shape[-1]

  assert qkv_features % num_heads == 0, (
      'Memory dimension must be divisible by number of heads.')
  head_dim = qkv_features // num_heads

  dense = nn.DenseGeneral.partial(
      axis=-1, features=(num_heads, head_dim), kernel_init=kernel_init,
      bias_init=bias_init, bias=bias, precision=precision)
  # Project inputs_q to multi-headed q/k/v.
  # Dimensions are then [bs, dims..., n_heads, n_features_per_head].
  qlength = inputs_q.shape[-2]
  bs = inputs_q.shape[0]
  kvlength = inputs_kv.shape[-2]
  query, key, value = (dense(inputs_q, dtype=dtype, name='query'),
                       dense(inputs_kv, dtype=dtype, name='key'),
                       dense(inputs_kv, dtype=dtype, name='value'))

  if cache:
    assert isinstance(cache, Cache), 'cache must be an instance of Cache'
    if self.is_initializing():
      cache.store(onp.array((key.ndim,) + key.shape[-2:], dtype=onp.int32))
    else:
      cache_entry = cache.retrieve(None)
      expected_shape = list(cache_entry.key.shape[:-2])
      for attn_dim in attention_axis:
        expected_shape[attn_dim] = 1
      expected_shape = tuple(expected_shape) + inputs_q.shape[-1:]
      if expected_shape != inputs_q.shape:
        raise ValueError('Invalid shape provided, expected shape %s '
                         'instead got %s.' % (expected_shape, inputs_q.shape))

      if not isinstance(cache_entry, _CacheEntry):
        raise ValueError('Cache is not initialized.')

      cshape = cache_entry.key.shape
      indices = [0] * len(cshape)
      i = cache_entry.i
      attn_size = onp.prod(onp.take(cshape, attention_axis))
      for attn_dim in attention_axis:
        attn_size //= cshape[attn_dim]
        indices[attn_dim] = i // attn_size
        i = i % attn_size

      key = lax.dynamic_update_slice(cache_entry.key, key, indices)
      value = lax.dynamic_update_slice(cache_entry.value, value, indices)
      one = jnp.array(1, jnp.uint32)
      cache_entry = cache_entry.replace(
          i=cache_entry.i + one, key=key, value=value)
      cache.store(cache_entry)

      key_padding_mask = jnp.broadcast_to(
          (jnp.arange(cshape[1]) < cache_entry.i), cshape[:2])
      key_padding_mask = key_padding_mask.astype(jnp.float32)[..., None]

  # Block reshape before attention.
  num_query_blocks = qlength // block_size
  num_kv_blocks = kvlength // block_size

  block_query = jnp.reshape(
      query, (bs, block_size, num_query_blocks, num_heads, head_dim))
  block_key = jnp.reshape(
      key, (bs, block_size, num_kv_blocks, num_heads, head_dim))
  block_value = jnp.reshape(
      value, (bs, block_size, num_kv_blocks, num_heads, head_dim))

  if causal_mask:
    # Causal masking needs to not have blocks with mixed information.
    sum_key = jnp.cumsum(block_key, axis=1)
    sum_key = sum_key[:, 0, :, :, :]  # take first item
  else:
    sum_key = jnp.sum(block_key, axis=1)

  # Sorting network on the head_dim dimension.
  sort_out = nn.DenseGeneral(
      sum_key, axis=-1, features=max_num_blocks, kernel_init=kernel_init,
      bias_init=bias_init, bias=bias, precision=precision)
  # sort_out has shape (bs, num_kv_blocks, num_heads, max_num_blocks).
  sort_out = sort_out[:, :, :, :num_query_blocks]

  # Simple softmax sorting first.
  if sort_activation == 'sinkhorn':
    permutation = sinkhorn_operator(
        jnp.reshape(sort_out, (-1, num_kv_blocks, num_query_blocks)),
        causal=causal_mask)
    permutation = jnp.reshape(
        permutation, (-1, num_kv_blocks, num_heads, num_query_blocks))
  else:
    if causal_mask:
      block_mask = _make_causal_mask(key, attention_axis)
      sort_out += block_mask
    permutation = jax.nn.softmax(sort_out, axis=-1)

  sorted_key = jnp.einsum('bskhd,bnhl->bsnhd', block_key, permutation)
  sorted_value = jnp.einsum('bskhd,bnhl->bsnhd', block_value, permutation)

  # Create attention masks.
  mask_components = []
  sorted_mask_components = []

  if causal_mask:
    # TODO(yitay): Test this causal masking.
    if cache and not self.is_initializing():
      bias_pre_shape = (1,) * (key.ndim - 1)
      attn_shape = tuple(onp.take(key.shape, attention_axis))
      attn_size = onp.prod(attn_shape)
      ii = jnp.arange(attn_size, dtype=jnp.uint32)
      mask = ii < cache_entry.i
      mask_components.append(mask.reshape(bias_pre_shape + attn_shape))
    else:
      mask_components.append(_make_causal_mask(key, attention_axis))

  if padding_mask is not None:
    # Divide the padding mask into blocks.
    padding_mask = jnp.reshape(padding_mask,
                               (bs * num_query_blocks, block_size, 1))
    if key_padding_mask is None:
      key_padding_mask = padding_mask
    padding_mask = make_padding_mask(
        padding_mask_query=padding_mask, padding_mask_key=key_padding_mask,
        query_shape=(bs * num_query_blocks, block_size, num_heads, head_dim),
        key_shape=(bs * num_kv_blocks, block_size, num_heads, head_dim),
        attention_axis=attention_axis)
    padding_mask = jnp.reshape(
        padding_mask, (bs, num_query_blocks, block_size, block_size))
    mask_components.append(padding_mask)
    sorted_padding_mask = jnp.einsum('bksj,bnhl->bnsj', padding_mask,
                                     permutation)
    sorted_mask_components.append(sorted_padding_mask)

  if segmentation is not None:
    if key_segmentation is None:
      key_segmentation = segmentation
    segmentation_mask = make_padding_mask(
        padding_mask_query=segmentation, padding_mask_key=key_segmentation,
        query_shape=(bs * num_query_blocks, block_size, num_heads, head_dim),
        key_shape=(bs * num_kv_blocks, block_size, num_heads, head_dim),
        attention_axis=attention_axis, segmentation_mask=True)
    segmentation_mask = jnp.reshape(
        segmentation_mask, (bs, num_query_blocks, block_size, block_size))
    mask_components.append(segmentation_mask)
    sorted_segmentation_mask = jnp.einsum('bksj,bnhl->bnsj', segmentation_mask,
                                          permutation)
    sorted_mask_components.append(sorted_segmentation_mask)

  if mask_components:
    attention_mask = mask_components[0]
    for component in mask_components[1:]:
      attention_mask = jnp.logical_and(attention_mask, component)

    # Attention mask in the form of attention bias.
    attention_bias = lax.select(
        attention_mask > 0,
        jnp.full(attention_mask.shape, 0.).astype(dtype),
        jnp.full(attention_mask.shape, -1e10).astype(dtype))
  else:
    attention_bias = None

  if sorted_mask_components:
    attention_mask = sorted_mask_components[0]
    for component in sorted_mask_components[1:]:
      attention_mask = jnp.logical_and(attention_mask, component)

    # Attention mask in the form of attention bias.
    sorted_attention_bias = lax.select(
        attention_mask > 0,
        jnp.full(attention_mask.shape, 0.).astype(dtype),
        jnp.full(attention_mask.shape, -1e10).astype(dtype))
  else:
    sorted_attention_bias = None

  # Apply attention to the original and to the block-sorted keys/values.
  x = local_dot_product_attention(
      block_query, block_key, block_value, dtype=dtype, axis=attention_axis,
      bias=attention_bias, precision=precision, dropout_rng=dropout_rng,
      dropout_rate=dropout_rate, broadcast_dropout=broadcast_dropout,
      deterministic=deterministic)

  sorted_x = local_dot_product_attention(
      block_query, sorted_key, sorted_value, dtype=dtype, axis=attention_axis,
      bias=sorted_attention_bias, precision=precision,
      dropout_rng=dropout_rng, dropout_rate=dropout_rate,
      broadcast_dropout=broadcast_dropout, deterministic=deterministic)

  x = x + sorted_x
  x = jnp.reshape(x, (bs, qlength, num_heads, head_dim))

  # Back to the original inputs dimensions.
  out = nn.DenseGeneral(
      x, features=features, axis=(-2, -1), kernel_init=kernel_init,
      bias_init=bias_init, bias=bias, dtype=dtype, precision=precision,
      name='out')

  return out
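
# Minimal sketch (illustrative only; the `sinkhorn_operator` used above is
# assumed to be defined elsewhere in this repo) of Sinkhorn normalization,
# which turns a matrix of block scores into an approximately doubly-stochastic
# "soft permutation" by alternately normalizing rows and columns in log space.
def _soft_permutation_sketch(log_scores, n_iters=20):
  """log_scores: [..., n_blocks, n_blocks] unnormalized block affinities."""
  x = log_scores
  for _ in range(n_iters):
    x = x - jax.nn.logsumexp(x, axis=-1, keepdims=True)  # normalize rows
    x = x - jax.nn.logsumexp(x, axis=-2, keepdims=True)  # normalize columns
  return jnp.exp(x)  # rows and columns each sum to approximately 1
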
def apply(self, inputs_q, inputs_kv, num_heads, dtype=jnp.float32,
          qkv_features=None, out_features=None, attention_axis=None,
          causal_mask=False, padding_mask=None, key_padding_mask=None,
          segmentation=None, key_segmentation=None, cache=None,
          broadcast_dropout=True, dropout_rng=None, dropout_rate=0.,
          deterministic=False, precision=None,
          kernel_init=nn.linear.default_kernel_init,
          bias_init=nn.initializers.zeros, bias=True, max_length=512,
          ignore_dot_product=True, synthesizer_mode='factorized_random',
          k=32):
  """Applies multi-head synthesizer attention on the input data.

  Projects the inputs into multi-headed query, key, and value vectors,
  applies dot-product attention and projects the results to an output vector.

  This can be used for encoder-decoder attention by specifying both
  `inputs_q` and `inputs_kv`, or for self-attention by only specifying
  `inputs_q` and setting `inputs_kv` to None.

  Args:
    inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
    inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]`
      or None for self-attention, in which case key/values will be derived
      from inputs_q.
    num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
      should be divisible by the number of heads.
    dtype: the dtype of the computation (default: float32).
    qkv_features: dimension of the key, query, and value.
    out_features: dimension of the last projection.
    attention_axis: axes over which the attention is applied (`None` means
      attention over all axes, but batch, heads, and features).
    causal_mask: boolean specifying whether to apply a causal mask on the
      attention weights. If True, the output at timestep `t` will not depend
      on inputs at timesteps strictly greater than `t`.
    padding_mask: boolean specifying query tokens that are pad token.
    key_padding_mask: boolean specifying key-value tokens that are pad token.
    segmentation: segment indices for packed inputs_q data.
    key_segmentation: segment indices for packed inputs_kv data.
    cache: an instance of `flax.nn.attention.Cache` used for efficient
      autoregressive decoding.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey: to be used for dropout.
    dropout_rate: dropout rate.
    deterministic: bool, deterministic or not (to apply dropout).
    precision: numerical precision of the computation, see `jax.lax.Precision`
      for details.
    kernel_init: initializer for the kernel of the Dense layers.
    bias_init: initializer for the bias of the Dense layers.
    bias: bool: whether pointwise QKVO dense transforms use bias.
    max_length: int, the maximum supported sequence length.
    ignore_dot_product: bool, whether to ignore the dot-product attention and
      rely only on the synthetic attention weights.
    synthesizer_mode: str, which synthetic weights to use; supports 'dense',
      'random', 'factorized_random', or combinations such as 'dense+random'.
    k: int, rank of the factorized random attention weights.

  Returns:
    output of shape `[bs, dim1, dim2, ..., dimN, features]`.
  """
  assert causal_mask or not cache, (
      'Caching is only supported for causal attention.')
  assert inputs_q.ndim == 3

  if inputs_kv is None:
    inputs_kv = inputs_q

  if attention_axis is None:
    attention_axis = tuple(range(1, inputs_q.ndim - 1))

  features = out_features or inputs_q.shape[-1]
  qkv_features = qkv_features or inputs_q.shape[-1]

  assert qkv_features % num_heads == 0, (
      'Memory dimension must be divisible by number of heads.')
  head_dim = qkv_features // num_heads

  dense = nn.DenseGeneral.partial(
      axis=-1, features=(num_heads, head_dim), kernel_init=kernel_init,
      bias_init=bias_init, bias=bias, precision=precision)
  # Project inputs_q to multi-headed q/k/v.
  # Dimensions are then [bs, dims..., n_heads, n_features_per_head].
  qlength = inputs_q.shape[-2]
  kvlength = inputs_kv.shape[-2]

  if ignore_dot_product:
    value = dense(inputs_kv, dtype=dtype, name='value')
    key = value
    query = inputs_q
  else:
    query, key, value = (dense(inputs_q, dtype=dtype, name='query'),
                         dense(inputs_kv, dtype=dtype, name='key'),
                         dense(inputs_kv, dtype=dtype, name='value'))

  syn_weights_list = []
  logging.info(synthesizer_mode)

  if 'random' in synthesizer_mode:
    if 'factorized_random' in synthesizer_mode:
      logging.info('Using factorized random')
      rand_syn_weights1 = self.param('random1', (num_heads, max_length, k),
                                     kernel_init)
      rand_syn_weights2 = self.param('random2', (num_heads, k, max_length),
                                     kernel_init)
      rand_syn_weights1 = rand_syn_weights1[:, :qlength, :]
      rand_syn_weights2 = rand_syn_weights2[:, :, :kvlength]
      rand_syn_weights = jnp.einsum('hlk,hkn->hln', rand_syn_weights1,
                                    rand_syn_weights2)
      rand_syn_weights = jax.lax.broadcast(rand_syn_weights,
                                           (inputs_q.shape[0],))
      syn_weights_list.append(rand_syn_weights)
    else:
      rand_syn_weights = self.param('random',
                                    (num_heads, max_length, max_length),
                                    kernel_init)
      rand_syn_weights = rand_syn_weights[:, :qlength, :kvlength]
      rand_syn_weights = jax.lax.broadcast(rand_syn_weights,
                                           (inputs_q.shape[0],))
      syn_weights_list.append(rand_syn_weights)

  if 'dense' in synthesizer_mode:
    dense_syn = nn.DenseGeneral.partial(
        axis=-1, features=(num_heads, head_dim), kernel_init=kernel_init,
        bias_init=bias_init, bias=bias, precision=precision,
        name='dense_syn', dtype=dtype)
    # TODO(yitay): Change this to nn.Dense and make sure it works.
    dense_syn_length = nn.linear.DenseGeneral.partial(
        axis=-1, features=max_length, kernel_init=kernel_init,
        bias_init=bias_init, bias=bias, precision=precision,
        name='dense_syn2', dtype=dtype)
    proj = dense_syn(inputs_q, dtype=dtype, name='dense_syn')
    proj = jax.nn.relu(proj)
    proj = dense_syn_length(proj, dtype=dtype, name='dense_syn_len')
    # TODO(yitay): Check if this reshape is needed.
    dense_syn_weights = proj.reshape(
        (inputs_q.shape[0], num_heads, qlength, max_length))
    dense_syn_weights = dense_syn_weights[:, :, :, :qlength]
    syn_weights_list.append(dense_syn_weights)

  if cache:
    assert isinstance(cache, Cache), 'cache must be an instance of Cache'
    if self.is_initializing():
      cache.store(onp.array((key.ndim,) + key.shape[-2:], dtype=onp.int32))
    else:
      cache_entry = cache.retrieve(None)
      expected_shape = list(cache_entry.key.shape[:-2])
      for attn_dim in attention_axis:
        expected_shape[attn_dim] = 1
      expected_shape = tuple(expected_shape) + inputs_q.shape[-1:]
      if expected_shape != inputs_q.shape:
        raise ValueError('Invalid shape provided, expected shape %s '
                         'instead got %s.' % (expected_shape, inputs_q.shape))

      if not isinstance(cache_entry, _CacheEntry):
        raise ValueError('Cache is not initialized.')

      cshape = cache_entry.key.shape
      indices = [0] * len(cshape)
      i = cache_entry.i
      attn_size = onp.prod(onp.take(cshape, attention_axis))
      for attn_dim in attention_axis:
        attn_size //= cshape[attn_dim]
        indices[attn_dim] = i // attn_size
        i = i % attn_size

      key = lax.dynamic_update_slice(cache_entry.key, key, indices)
      value = lax.dynamic_update_slice(cache_entry.value, value, indices)
      one = jnp.array(1, jnp.uint32)
      cache_entry = cache_entry.replace(
          i=cache_entry.i + one, key=key, value=value)
      cache.store(cache_entry)

      key_padding_mask = jnp.broadcast_to(
          (jnp.arange(cshape[1]) < cache_entry.i), cshape[:2])
      key_padding_mask = key_padding_mask.astype(jnp.float32)[..., None]

  # Create attention masks.
  mask_components = []

  if causal_mask:
    if cache and not self.is_initializing():
      bias_pre_shape = (1,) * (key.ndim - 1)
      attn_shape = tuple(onp.take(key.shape, attention_axis))
      attn_size = onp.prod(attn_shape)
      ii = jnp.arange(attn_size, dtype=jnp.uint32)
      mask = ii < cache_entry.i
      mask_components.append(mask.reshape(bias_pre_shape + attn_shape))
    else:
      mask_components.append(_make_causal_mask(key, attention_axis))

  if not ignore_dot_product:
    if padding_mask is not None:
      if key_padding_mask is None:
        key_padding_mask = padding_mask
      padding_mask = make_padding_mask(
          padding_mask_query=padding_mask, padding_mask_key=key_padding_mask,
          query_shape=query.shape, key_shape=key.shape,
          attention_axis=attention_axis)
      mask_components.append(padding_mask)

    if segmentation is not None:
      if key_segmentation is None:
        key_segmentation = segmentation
      segmentation_mask = make_padding_mask(
          padding_mask_query=segmentation, padding_mask_key=key_segmentation,
          query_shape=query.shape, key_shape=key.shape,
          attention_axis=attention_axis, segmentation_mask=True)
      mask_components.append(segmentation_mask)

  if mask_components:
    attention_mask = mask_components[0]
    for component in mask_components[1:]:
      attention_mask = jnp.logical_and(attention_mask, component)

    # Attention mask in the form of attention bias.
    attention_bias = lax.select(
        attention_mask > 0,
        jnp.full(attention_mask.shape, 0.).astype(dtype),
        jnp.full(attention_mask.shape, -1e10).astype(dtype))
  else:
    attention_bias = None

  # Apply attention.
  x = synthetic_attention(
      query, key, value, syn_weights_list, dtype=dtype, axis=attention_axis,
      bias=attention_bias, precision=precision, dropout_rng=dropout_rng,
      dropout_rate=dropout_rate, broadcast_dropout=broadcast_dropout,
      deterministic=deterministic, ignore_dot_product=ignore_dot_product)

  # Back to the original inputs dimensions.
  out = nn.DenseGeneral(
      x, features=features, axis=(-2, -1), kernel_init=kernel_init,
      bias_init=bias_init, bias=bias, dtype=dtype, precision=precision,
      name='out')

  return out
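
# Minimal sketch (illustrative only, not the module above) of factorized
# "random" synthesizer weights: the attention matrix is a learned parameter,
# independent of the token content, stored as a low-rank product R1 @ R2 and
# cropped to the actual sequence lengths.
def _factorized_random_synthesizer_sketch(r1, r2, values, qlen, kvlen):
  """r1: [heads, max_len, k]; r2: [heads, k, max_len]; values: [kvlen, heads, d]."""
  logits = jnp.einsum('hlk,hkn->hln', r1[:, :qlen, :], r2[:, :, :kvlen])
  weights = jax.nn.softmax(logits, axis=-1)           # [heads, qlen, kvlen]
  return jnp.einsum('hqk,khd->qhd', weights, values)  # [qlen, heads, d]
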
def apply(self, inputs_q, inputs_kv, num_heads, dtype=jnp.float32,
          qkv_features=None, out_features=None, causal_mask=False,
          padding_mask=None, key_padding_mask=None, segmentation=None,
          key_segmentation=None, cache=None, broadcast_dropout=True,
          dropout_rng=None, dropout_rate=0., deterministic=False,
          precision=None, kernel_init=nn.linear.default_kernel_init,
          bias_init=nn.initializers.zeros, bias=True):
  """Applies linear attention on the input data.

  Projects the inputs into multi-headed query, key, and value vectors,
  applies linear attention and projects the results to an output vector.

  Args:
    inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
    inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]`
      or None for self-attention, in which case key/values will be derived
      from inputs_q.
    num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
      should be divisible by the number of heads.
    dtype: the dtype of the computation (default: float32).
    qkv_features: dimension of the key, query, and value.
    out_features: dimension of the last projection.
    causal_mask: boolean specifying whether to apply a causal mask on the
      attention weights. If True, the output at timestep `t` will not depend
      on inputs at timesteps strictly greater than `t`.
    padding_mask: boolean specifying query tokens that are pad token.
    key_padding_mask: boolean specifying key-value tokens that are pad token.
    segmentation: segment indices for packed inputs_q data.
    key_segmentation: segment indices for packed inputs_kv data.
    cache: an instance of `flax.nn.attention.Cache` used for efficient
      autoregressive decoding.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey: to be used for dropout.
    dropout_rate: dropout rate.
    deterministic: bool, deterministic or not (to apply dropout).
    precision: numerical precision of the computation, see `jax.lax.Precision`
      for details.
    kernel_init: initializer for the kernel of the Dense layers.
    bias_init: initializer for the bias of the Dense layers.
    bias: bool: whether pointwise QKVO dense transforms use bias.

  Returns:
    output of shape `[bs, dim1, dim2, ..., dimN, features]`.
  """
  if padding_mask is not None:
    raise NotImplementedError(
        'Padding masks are not yet supported by LinearAttention.')

  assert causal_mask or not cache, (
      'Caching is only supported for causal attention.')
  assert inputs_q.ndim == 3

  if inputs_kv is None:
    inputs_kv = inputs_q

  features = out_features or inputs_q.shape[-1]
  qkv_features = qkv_features or inputs_q.shape[-1]

  assert qkv_features % num_heads == 0, (
      'Memory dimension must be divisible by number of heads.')
  head_dim = qkv_features // num_heads

  dense = nn.DenseGeneral.partial(
      axis=-1, features=(num_heads, head_dim), kernel_init=kernel_init,
      bias_init=bias_init, bias=bias, precision=precision)
  # Project inputs_q to multi-headed q/k/v.
  # Dimensions are then [bs, dims..., n_heads, n_features_per_head].
  query, key, value = (dense(inputs_q, dtype=dtype, name='query'),
                       dense(inputs_kv, dtype=dtype, name='key'),
                       dense(inputs_kv, dtype=dtype, name='value'))

  if cache:
    raise NotImplementedError('Decoding not supported in LinearAttention.')

  # Apply linear attention.
  x = linear_attention(
      query, key, value, dropout_rng=dropout_rng, dropout_rate=dropout_rate,
      broadcast_dropout=broadcast_dropout, deterministic=deterministic)

  # Back to the original inputs dimensions.
  out = nn.DenseGeneral(
      x, features=features, axis=(-2, -1), kernel_init=kernel_init,
      bias_init=bias_init, bias=bias, dtype=dtype, precision=precision,
      name='out')

  return out
def apply(self, inputs_q, inputs_kv, num_heads, sliding_window_size=512,
          global_mask=None, causal_mask=False, dtype=jnp.float32,
          qkv_features=None, out_features=None, padding_mask=None,
          key_padding_mask=None, segmentation=None, key_segmentation=None,
          broadcast_dropout=True, dropout_rng=None, dropout_rate=0.,
          deterministic=False, precision=None,
          kernel_init=nn.linear.default_kernel_init,
          bias_init=nn.initializers.zeros, bias=True):
  """Applies Longformer multi-head dot product attention on the input data.

  Args:
    inputs_q: input queries of shape `[bs, seq_len, features]`.
    inputs_kv: key/values of shape `[bs, seq_len, features]` or `None` for
      self-attention, in which case key/values will be derived from inputs_q.
    num_heads: number of attention heads (should divide number of features).
    sliding_window_size: size of sliding window attention to use.
    global_mask: boolean matrix of shape `[bs, seq_len]`, where `True`
      indicates that the position is globally attended. By default, no global
      attention is used.
    causal_mask: if true, apply causal attention masking.
    dtype: the dtype of the computation (default: float32).
    qkv_features: dimension of the key, query, and value.
    out_features: dimension of the last projection.
    padding_mask: boolean specifying query tokens that are pad token.
    key_padding_mask: boolean specifying key-value tokens that are pad token.
    segmentation: segment indices for packed inputs_q data.
    key_segmentation: segment indices for packed inputs_kv data.
    broadcast_dropout: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey to be used for dropout.
    dropout_rate: dropout rate.
    deterministic: if true, dropout is not applied.
    precision: numerical precision of the computation.
    kernel_init: initializer for the kernel of the Dense layers.
    bias_init: initializer for the bias of the Dense layers.
    bias: whether pointwise QKVO dense transforms use bias.

  Returns:
    output of shape `[bs, seq_len, features]`.
  """
  if inputs_kv is None:
    inputs_kv = inputs_q

  batch_size = inputs_q.shape[0]
  features = out_features or inputs_q.shape[-1]
  qkv_features = qkv_features or inputs_q.shape[-1]
  seq_len = inputs_q.shape[1]

  assert qkv_features % num_heads == 0, (
      'Memory dimension must be divisible by number of heads.')
  head_dim = qkv_features // num_heads

  dense = nn.DenseGeneral.partial(
      axis=-1, features=(num_heads, head_dim), kernel_init=kernel_init,
      bias_init=bias_init, bias=bias, precision=precision)

  query_sw = dense(inputs_q, dtype=dtype, name='query_sliding_window')
  key_sw = dense(inputs_kv, dtype=dtype, name='key_sliding_window')
  value_sw = dense(inputs_kv, dtype=dtype, name='value_sliding_window')

  query_global = dense(inputs_q, dtype=dtype, name='query_global')
  key_global = dense(inputs_kv, dtype=dtype, name='key_global')
  value_global = dense(inputs_kv, dtype=dtype, name='value_global')

  if global_mask is None:
    global_mask = jnp.full((batch_size, seq_len), False)

  full_global_mask = _build_global_mask(global_mask)
  sliding_window_mask = _build_sliding_window_mask(
      window_size=sliding_window_size, global_mask=global_mask)

  x_sw = _get_attention_result(
      query=query_sw, key=key_sw, value=value_sw, dtype=dtype,
      precision=precision, dropout_rng=dropout_rng, dropout_rate=dropout_rate,
      broadcast_dropout=broadcast_dropout, deterministic=deterministic,
      mask=sliding_window_mask, padding_mask=padding_mask,
      key_padding_mask=key_padding_mask, segmentation=segmentation,
      key_segmentation=key_segmentation, apply_causal_mask=causal_mask)

  x_global = _get_attention_result(
      query=query_global, key=key_global, value=value_global, dtype=dtype,
      precision=precision, dropout_rng=dropout_rng, dropout_rate=dropout_rate,
      broadcast_dropout=broadcast_dropout, deterministic=deterministic,
      mask=full_global_mask, padding_mask=padding_mask,
      key_padding_mask=key_padding_mask, segmentation=segmentation,
      key_segmentation=key_segmentation, apply_causal_mask=causal_mask)

  # Use global attention at globally-attended positions, sliding-window
  # attention everywhere else.
  x = jnp.where(global_mask[:, :, jnp.newaxis, jnp.newaxis], x_global, x_sw)

  # Back to the original inputs dimensions.
  out = nn.DenseGeneral(
      x, features=features, axis=(-2, -1), kernel_init=kernel_init,
      bias_init=bias_init, bias=bias, dtype=dtype, precision=precision,
      name='out')

  return out
def apply(self, inputs_q, inputs_kv, num_heads, dtype=jnp.float32,
          qkv_features=None, out_features=None, attention_axis=None,
          causal_mask=False, padding_mask=None, key_padding_mask=None,
          segmentation=None, key_segmentation=None, cache=None,
          broadcast_dropout=True, dropout_rng=None, dropout_rate=0.,
          deterministic=False, precision=None,
          kernel_init=nn.linear.default_kernel_init,
          bias_init=nn.initializers.zeros, bias=True, low_rank_features=16,
          max_len=1000):
  """Applies Linformer's low-rank attention on the input data.

  Projects the inputs into multi-headed query, key, and value vectors,
  applies dot-product attention and projects the results to an output vector.

  This can be used for encoder-decoder attention by specifying both
  `inputs_q` and `inputs_kv`, or for self-attention by only specifying
  `inputs_q` and setting `inputs_kv` to None.

  Args:
    inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
    inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]`
      or None for self-attention, in which case key/values will be derived
      from inputs_q.
    num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
      should be divisible by the number of heads.
    dtype: the dtype of the computation (default: float32).
    qkv_features: dimension of the key, query, and value.
    out_features: dimension of the last projection.
    attention_axis: axes over which the attention is applied (`None` means
      attention over all axes, but batch, heads, and features).
    causal_mask: boolean specifying whether to apply a causal mask on the
      attention weights. If True, the output at timestep `t` will not depend
      on inputs at timesteps strictly greater than `t`.
    padding_mask: boolean specifying query tokens that are pad token.
    key_padding_mask: boolean specifying key-value tokens that are pad token.
    segmentation: segment indices for packed inputs_q data.
    key_segmentation: segment indices for packed inputs_kv data.
    cache: an instance of `flax.nn.attention.Cache` used for efficient
      autoregressive decoding.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey: to be used for dropout.
    dropout_rate: dropout rate.
    deterministic: bool, deterministic or not (to apply dropout).
    precision: numerical precision of the computation, see `jax.lax.Precision`
      for details.
    kernel_init: initializer for the kernel of the Dense layers.
    bias_init: initializer for the bias of the Dense layers.
    bias: bool: whether pointwise QKVO dense transforms use bias.
    low_rank_features: int, how many low-rank projected features.
    max_len: int, maximum sequence length.

  Returns:
    output of shape `[bs, dim1, dim2, ..., dimN, features]`.
  """
  assert causal_mask or not cache, (
      'Caching is only supported for causal attention.')
  assert inputs_q.ndim == 3

  if inputs_kv is None:
    inputs_kv = inputs_q

  if attention_axis is None:
    attention_axis = tuple(range(1, inputs_q.ndim - 1))

  features = out_features or inputs_q.shape[-1]
  qkv_features = qkv_features or inputs_q.shape[-1]

  assert qkv_features % num_heads == 0, (
      'Memory dimension must be divisible by number of heads.')
  head_dim = qkv_features // num_heads

  dense = nn.DenseGeneral.partial(
      axis=-1, features=(num_heads, head_dim), kernel_init=kernel_init,
      bias_init=bias_init, bias=bias, precision=precision)
  # Project inputs_q to multi-headed q/k/v.
  # Dimensions are then [bs, dims..., n_heads, n_features_per_head].
  query, key, value = (dense(inputs_q, dtype=dtype, name='query'),
                       dense(inputs_kv, dtype=dtype, name='key'),
                       dense(inputs_kv, dtype=dtype, name='value'))

  def low_rank_projection(inputs, kernel, precision):
    """Projects the length dimension of `inputs` down to low-rank features."""
    input_dim = inputs.shape[1]
    # This kernel/parameter relies on the sequence length.
    kernel = kernel[:input_dim, :]
    inputs = inputs.transpose((0, 3, 2, 1))
    y = lax.dot_general(
        inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ())),
        precision=precision)
    y = y.transpose((0, 3, 2, 1))
    return y

  # Shared kernel for low-rank length-dimension projections.
  low_rank_kernel = self.param('lr_kernel', (max_len, low_rank_features),
                               kernel_init)

  key = low_rank_projection(key, low_rank_kernel, precision)
  value = low_rank_projection(value, low_rank_kernel, precision)

  if cache:
    raise NotImplementedError('Decoding not supported in Linformer.')

  # TODO(yitay): Does Linformer care about masks?
  # Since everything is mixed in the length dimension, are masks relevant?

  # Apply regular dot product attention on the length-projected keys/values.
  x = dot_product_attention(
      query, key, value, dtype=dtype, axis=attention_axis, bias=None,
      precision=precision, dropout_rng=dropout_rng, dropout_rate=dropout_rate,
      broadcast_dropout=broadcast_dropout, deterministic=deterministic)

  # Back to the original inputs dimensions.
  out = nn.DenseGeneral(
      x, features=features, axis=(-2, -1), kernel_init=kernel_init,
      bias_init=bias_init, bias=bias, dtype=dtype, precision=precision,
      name='out')

  return out
def apply(self, inputs_q, inputs_kv, num_heads, block_size=64,
          num_rand_blocks=3, dtype=jnp.float32, qkv_features=None,
          out_features=None, attention_axis=None, causal_mask=False,
          padding_mask=None, key_padding_mask=None, segmentation=None,
          key_segmentation=None, cache=None, broadcast_dropout=True,
          dropout_rng=None, dropout_rate=0., deterministic=False,
          precision=None, kernel_init=nn.linear.default_kernel_init,
          bias_init=nn.initializers.zeros, bias=True, connectivity_seed=None):
  """Applies block-sparse multi-head dot product attention on the input data.

  Projects the inputs into multi-headed query, key, and value vectors,
  applies dot-product attention and projects the results to an output vector.

  This can be used for encoder-decoder attention by specifying both
  `inputs_q` and `inputs_kv`, or for self-attention by only specifying
  `inputs_q` and setting `inputs_kv` to None.

  Args:
    inputs_q: input queries of shape `[bs, length, features]`.
    inputs_kv: key/values of shape `[bs, length, features]` or None for
      self-attention, in which case key/values will be derived from inputs_q.
    num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
      should be divisible by the number of heads.
    block_size: size of local attention blocks around the diagonal of the
      attention matrix.
    num_rand_blocks: int, number of random chunks per row.
    dtype: the dtype of the computation (default: float32).
    qkv_features: dimension of the key, query, and value.
    out_features: dimension of the last projection.
    attention_axis: axes over which the attention is applied (`None` means
      attention over all axes, but batch, heads, and features).
    causal_mask: boolean specifying whether to apply a causal mask on the
      attention weights. If True, the output at timestep `t` will not depend
      on inputs at timesteps strictly greater than `t`.
    padding_mask: boolean specifying query tokens that are pad token.
    key_padding_mask: boolean specifying key-value tokens that are pad token.
    segmentation: segment indices for packed inputs_q data.
    key_segmentation: segment indices for packed inputs_kv data.
    cache: an instance of `flax.nn.attention.Cache` used for efficient
      autoregressive decoding.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey: to be used for dropout.
    dropout_rate: dropout rate.
    deterministic: bool, deterministic or not (to apply dropout).
    precision: numerical precision of the computation, see `jax.lax.Precision`
      for details.
    kernel_init: initializer for the kernel of the Dense layers.
    bias_init: initializer for the bias of the Dense layers.
    bias: bool: whether pointwise QKVO dense transforms use bias.
    connectivity_seed: seed for random block sparse attention.

  Returns:
    output of shape `[bs, length, features]`.
  """
  orig_seqlen = inputs_q.shape[-2]
  logging.info(inputs_q)
  extra_len = block_size - (orig_seqlen % block_size)
  pad_width = jnp.array([[0, 0], [0, extra_len], [0, 0]])
  mask_pad = jnp.array([[0, 0], [0, extra_len], [0, 0]])

  # Pad the inputs and masks up to a multiple of the block size.
  padding_mask = jnp.pad(padding_mask, mask_pad, constant_values=-1e9)
  inputs_q = jnp.pad(inputs_q, pad_width)
  if inputs_kv is not None:
    inputs_kv = jnp.pad(inputs_kv, pad_width)

  assert causal_mask or not cache, (
      'Caching is only supported for causal attention.')

  if inputs_kv is None:
    inputs_kv = inputs_q

  if attention_axis is None:
    attention_axis = tuple(range(1, inputs_q.ndim - 1))

  features = out_features or inputs_q.shape[-1]
  qkv_features = qkv_features or inputs_q.shape[-1]

  assert qkv_features % num_heads == 0, (
      'Memory dimension must be divisible by number of heads.')
  head_dim = qkv_features // num_heads

  dense = nn.DenseGeneral.partial(
      axis=-1, features=(num_heads, head_dim), kernel_init=kernel_init,
      bias_init=bias_init, bias=bias, precision=precision)
  # Project inputs_q to multi-headed q/k/v.
  # Dimensions are then [bs, dims..., n_heads, n_features_per_head].
  query, key, value = (dense(inputs_q, dtype=dtype, name='query'),
                       dense(inputs_kv, dtype=dtype, name='key'),
                       dense(inputs_kv, dtype=dtype, name='value'))

  if cache:
    assert isinstance(cache, attention.Cache), (
        'cache must be an instance of Cache')
    if self.is_initializing():
      cache.store(onp.array((key.ndim,) + key.shape[-2:], dtype=onp.int32))
    else:
      cache_entry = cache.retrieve(None)
      expected_shape = list(cache_entry.key.shape[:-2])
      for attn_dim in attention_axis:
        expected_shape[attn_dim] = 1
      expected_shape = tuple(expected_shape) + inputs_q.shape[-1:]
      if expected_shape != inputs_q.shape:
        raise ValueError('Invalid shape provided, expected shape %s '
                         'instead got %s.' % (expected_shape, inputs_q.shape))

      if not isinstance(cache_entry, attention._CacheEntry):  # pylint: disable=protected-access
        raise ValueError('Cache is not initialized.')

      cshape = cache_entry.key.shape
      indices = [0] * len(cshape)
      i = cache_entry.i
      attn_size = onp.prod(onp.take(cshape, attention_axis))
      for attn_dim in attention_axis:
        attn_size //= cshape[attn_dim]
        indices[attn_dim] = i // attn_size
        i = i % attn_size

      key = lax.dynamic_update_slice(cache_entry.key, key, indices)
      value = lax.dynamic_update_slice(cache_entry.value, value, indices)
      one = jnp.array(1, jnp.uint32)
      cache_entry = cache_entry.replace(
          i=cache_entry.i + one, key=key, value=value)
      cache.store(cache_entry)

      # TODO(levskaya): verify this is still needed in translation decoding.
      key_padding_mask = jnp.broadcast_to(
          (jnp.arange(cshape[1]) < cache_entry.i), cshape[:2])
      key_padding_mask = key_padding_mask.astype(jnp.float32)[..., None]

  if causal_mask:
    # Falls back to full attention with a causal mask.
    # Create attention masks.
    mask_components = []

    if causal_mask:
      if cache and not self.is_initializing():
        bias_pre_shape = (1,) * (key.ndim - 1)
        attn_shape = tuple(onp.take(key.shape, attention_axis))
        attn_size = onp.prod(attn_shape)
        ii = jnp.arange(attn_size, dtype=jnp.uint32)
        mask = ii < cache_entry.i
        mask_components.append(mask.reshape(bias_pre_shape + attn_shape))
      else:
        mask_components.append(
            attention._make_causal_mask(key, attention_axis))  # pylint: disable=protected-access

    if padding_mask is not None:
      if key_padding_mask is None:
        key_padding_mask = padding_mask
      padding_mask = attention.make_padding_mask(
          padding_mask_query=padding_mask, padding_mask_key=key_padding_mask,
          query_shape=query.shape, key_shape=key.shape,
          attention_axis=attention_axis)
      mask_components.append(padding_mask)

    if segmentation is not None:
      if key_segmentation is None:
        key_segmentation = segmentation
      segmentation_mask = attention.make_padding_mask(
          padding_mask_query=segmentation, padding_mask_key=key_segmentation,
          query_shape=query.shape, key_shape=key.shape,
          attention_axis=attention_axis, segmentation_mask=True)
      mask_components.append(segmentation_mask)

    if mask_components:
      attention_mask = mask_components[0]
      for component in mask_components[1:]:
        attention_mask = jnp.logical_and(attention_mask, component)

      # Attention mask in the form of attention bias.
      attention_bias = lax.select(
          attention_mask > 0,
          jnp.full(attention_mask.shape, 0.).astype(dtype),
          jnp.full(attention_mask.shape, -1e10).astype(dtype))
    else:
      attention_bias = None

    x = nn.dot_product_attention(
        query, key, value, dtype=dtype, axis=attention_axis,
        bias=attention_bias, precision=precision, dropout_rng=dropout_rng,
        dropout_rate=dropout_rate, broadcast_dropout=broadcast_dropout,
        deterministic=deterministic)
  else:
    if connectivity_seed is None:
      path = self._get_construction_frame().path
      connectivity_seed = hash(path) % 2**32
    # Apply block-sparse attention.
    input_mask = None
    if padding_mask is not None:
      input_mask = padding_mask.astype(key.dtype)
    x = sparse_dot_product_attention(
        query, key, value, connectivity_seed=connectivity_seed,
        input_mask=input_mask, block_size=block_size,
        num_rand_blocks=num_rand_blocks)

  # Back to the original inputs dimensions.
  out = nn.DenseGeneral(
      x, features=features, axis=(-2, -1), kernel_init=kernel_init,
      bias_init=bias_init, bias=bias, dtype=dtype, precision=precision,
      name='out')

  # Remove the block padding added above.
  out = out[:, :orig_seqlen, :]
  return out
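
# Minimal sketch (illustrative only; the real pattern lives in
# `sparse_dot_product_attention`) of BigBird-style block connectivity: each
# query block attends to a local window of neighboring blocks, a few random
# blocks, and a global block that attends to (and is attended by) everything.
def _bigbird_block_connectivity_sketch(num_blocks, num_rand_blocks, seed=0):
  """Returns a [num_blocks, num_blocks] bool matrix of allowed block pairs."""
  rng = onp.random.default_rng(seed)
  allowed = onp.zeros((num_blocks, num_blocks), dtype=bool)
  for qb in range(num_blocks):
    allowed[qb, max(0, qb - 1):qb + 2] = True                    # local window
    allowed[qb, rng.choice(num_blocks, num_rand_blocks, replace=False)] = True
    allowed[qb, 0] = True                                        # global block
    allowed[0, qb] = True
  return allowed
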