Example #1
    def apply(self,
              inputs_q,
              inputs_kv,
              num_heads,
              dtype=jnp.float32,
              qkv_features=None,
              out_features=None,
              attention_axis=None,
              causal_mask=False,
              padding_mask=None,
              key_padding_mask=None,
              segmentation=None,
              key_segmentation=None,
              cache=None,
              broadcast_dropout=True,
              dropout_rng=None,
              dropout_rate=0.,
              deterministic=False,
              precision=None,
              kernel_init=nn.linear.default_kernel_init,
              bias_init=nn.initializers.zeros,
              bias=True,
              num_partitions=2):
        """Applies multi-head dot product attention on the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention, and projects the results to an output vector.

    This can be used for encoder-decoder attention by specifying both `inputs_q`
    and `inputs_kv`, or for self-attention by only specifying `inputs_q` and
    setting `inputs_kv` to None.

    Args:
      inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
      inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]`
        or None for self-attention, in which case key/values will be derived
        from inputs_q.
      num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
        should be divisible by the number of heads.
      dtype: the dtype of the computation (default: float32)
      qkv_features: dimension of the key, query, and value.
      out_features: dimension of the last projection
      attention_axis: axes over which the attention is applied (`None` means
        attention over all axes except batch, heads, and features).
      causal_mask: boolean specifying whether to apply a causal mask on the
        attention weights. If True, the output at timestep `t` will not depend
        on inputs at timesteps strictly greater than `t`.
      padding_mask: boolean specifying query tokens that are pad tokens.
      key_padding_mask: boolean specifying key-value tokens that are pad tokens.
      segmentation: segment indices for packed inputs_q data.
      key_segmentation: segment indices for packed inputs_kv data.
      cache: an instance of `flax.nn.attention.Cache` used for efficient
        autoregressive decoding.
      broadcast_dropout: bool: use a broadcasted dropout along batch dims.
      dropout_rng: JAX PRNGKey: to be used for dropout
      dropout_rate: dropout rate
      deterministic: bool, whether to run deterministically (i.e. without
        applying dropout).
      precision: numerical precision of the computation; see `jax.lax.Precision`
        for details.
      kernel_init: initializer for the kernel of the Dense layers.
      bias_init: initializer for the bias of the Dense layers.
      bias: bool: whether pointwise QKVO dense transforms use bias.
      num_partitions: number of ways to partition (i.e. how many devices to run
        across).

    Returns:
      output of shape `[bs, dim1, dim2, ..., dimN, features]`.
    """

        assert causal_mask or not cache, (
            'Caching is only supported for causal attention.')

        if inputs_kv is None:
            inputs_kv = inputs_q

        if attention_axis is None:
            attention_axis = tuple(range(1, inputs_q.ndim - 1))

        features = out_features or inputs_q.shape[-1]
        qkv_features = qkv_features or inputs_q.shape[-1]

        assert qkv_features % num_heads == 0, (
            'Memory dimension must be divisible by number of heads.')
        head_dim = qkv_features // num_heads

        dense = nn.DenseGeneral.partial(axis=-1,
                                        features=(num_heads, head_dim),
                                        kernel_init=kernel_init,
                                        bias_init=bias_init,
                                        bias=bias,
                                        precision=precision)
        # project inputs_q to multi-headed q/k/v
        # dimensions are then [bs, dims..., n_heads, n_features_per_head]
        query, key, value = (dense(inputs_q, dtype=dtype, name='query'),
                             dense(inputs_kv, dtype=dtype, name='key'),
                             dense(inputs_kv, dtype=dtype, name='value'))
        if num_partitions > 1:
            partitions = P(1, 1, num_partitions, 1)
            query = with_sharding_constraint(query, partitions)
            key = with_sharding_constraint(key, partitions)
            value = with_sharding_constraint(value, partitions)

        if cache:
            assert isinstance(cache,
                              Cache), 'cache must be an instance of Cache'
            if self.is_initializing():
                cache.store(lambda: (key.ndim, key.shape[-2:]))
            else:
                cache_entry = cache.retrieve(None)
                expected_shape = list(cache_entry.key.shape[:-2])
                for attn_dim in attention_axis:
                    expected_shape[attn_dim] = 1
                expected_shape = tuple(expected_shape) + inputs_q.shape[-1:]
                if expected_shape != inputs_q.shape:
                    raise ValueError('Invalid shape provided, '
                                     'expected shape %s instead got %s.' %
                                     (expected_shape, inputs_q.shape))

                if not isinstance(cache_entry, _CacheEntry):
                    raise ValueError('Cache is not initialized.')

                cshape = cache_entry.key.shape
                i = cache_entry.i
                one_hot_indices = jax.nn.one_hot(i, cshape[3],
                                                 dtype=key.dtype).reshape(
                                                     (1, 1, 1, cshape[3]))
                key = key.transpose((0, 2, 3, 1))
                key = cache_entry.key + key * one_hot_indices
                value = value.transpose((0, 2, 3, 1))
                value = cache_entry.value + value * one_hot_indices

                one = jnp.array(1, jnp.uint32)
                cache_entry = cache_entry.replace(i=cache_entry.i + one,
                                                  key=key,
                                                  value=value)
                cache.store(cache_entry)

                key = key.transpose((0, 3, 1, 2))
                value = value.transpose((0, 3, 1, 2))
                cshape = (cshape[0], cshape[3], cshape[1], cshape[2])

                # TODO(levskaya): verify this is still needed in translation decoding.
                key_padding_mask = jnp.broadcast_to(
                    (jnp.arange(cshape[1]) < cache_entry.i), cshape[:2])
                key_padding_mask = key_padding_mask.astype(jnp.float32)[...,
                                                                        None]

        # create attention masks
        mask_components = []

        if causal_mask:
            if cache and not self.is_initializing():
                bias_pre_shape = (1, ) * (key.ndim - 1)
                attn_shape = tuple(np.take(key.shape, attention_axis))
                attn_size = np.prod(attn_shape)
                ii = jnp.arange(attn_size, dtype=jnp.uint32)
                mask = ii < cache_entry.i
                mask_components.append(
                    mask.reshape(bias_pre_shape + attn_shape))
            else:
                mask_components.append(_make_causal_mask(key, attention_axis))

        if padding_mask is not None:
            if key_padding_mask is None:
                key_padding_mask = padding_mask
            padding_mask = make_padding_mask(padding_mask_query=padding_mask,
                                             padding_mask_key=key_padding_mask,
                                             query_shape=query.shape,
                                             key_shape=key.shape,
                                             attention_axis=attention_axis)
            mask_components.append(padding_mask)

        if segmentation is not None:
            if key_segmentation is None:
                key_segmentation = segmentation
            segmentation_mask = make_padding_mask(
                padding_mask_query=segmentation,
                padding_mask_key=key_segmentation,
                query_shape=query.shape,
                key_shape=key.shape,
                attention_axis=attention_axis,
                segmentation_mask=True)
            mask_components.append(segmentation_mask)

        if mask_components:
            attention_mask = mask_components[0]
            for component in mask_components[1:]:
                attention_mask = jnp.logical_and(attention_mask, component)

            # attention mask in the form of attention bias
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.).astype(dtype),
                jnp.full(attention_mask.shape, -1e10).astype(dtype))
        else:
            attention_bias = None

        # apply attention
        x = dot_product_attention(query,
                                  key,
                                  value,
                                  dtype=dtype,
                                  axis=attention_axis,
                                  bias=attention_bias,
                                  precision=precision,
                                  dropout_rng=dropout_rng,
                                  dropout_rate=dropout_rate,
                                  broadcast_dropout=broadcast_dropout,
                                  deterministic=deterministic)

        # back to the original inputs dimensions
        out = nn.DenseGeneral(x,
                              features=features,
                              axis=(-2, -1),
                              kernel_init=kernel_init,
                              bias_init=bias_init,
                              bias=bias,
                              dtype=dtype,
                              precision=precision,
                              name='out')
        if num_partitions > 1:
            # Note: this constrains `x`, which is no longer used after the
            # output projection above; the returned value is `out`.
            x = with_sharding_constraint(x, None)

        return out
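A minimal usage sketch for a module with the `apply` signature shown in Example #1. It assumes an old Flax release where the deprecated pre-Linen `flax.nn` API is still available, and uses the library's own `nn.MultiHeadDotProductAttention` (which has essentially this signature, minus `num_partitions`) as a stand-in, since the enclosing class of the snippet above is not shown:

import jax
import jax.numpy as jnp
from flax import nn  # deprecated pre-Linen Flax API, as used by the snippets

rng = jax.random.PRNGKey(0)
x = jnp.ones((2, 16, 64))  # [batch, length, features]

# Fix the hyperparameters with .partial, then initialize and apply.
module_def = nn.MultiHeadDotProductAttention.partial(
    num_heads=4, qkv_features=64, causal_mask=True)
_, params = module_def.init(rng, x, None)  # inputs_kv=None -> self-attention
model = nn.Model(module_def, params)
out = model(x, None)                       # shape (2, 16, 64)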
Example #2
  def apply(self,
            inputs_q,
            inputs_kv,
            num_heads,
            dtype=jnp.float32,
            qkv_features=None,
            out_features=None,
            attention_axis=None,
            causal_mask=False,
            padding_mask=None,
            key_padding_mask=None,
            segmentation=None,
            key_segmentation=None,
            cache=None,
            broadcast_dropout=True,
            dropout_rng=None,
            dropout_rate=0.,
            deterministic=False,
            precision=None,
            kernel_init=nn.linear.default_kernel_init,
            bias_init=nn.initializers.zeros,
            bias=True,
            block_size=50,
            max_num_blocks=25,
            sort_activation='softmax'):
    """Applies multi-head sinkhorn attention on the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention, and projects the results to an output vector.

    This can be used for encoder-decoder attention by specifying both `inputs_q`
    and `inputs_kv`, or for self-attention by only specifying `inputs_q` and
    setting `inputs_kv` to None.

    Args:
      inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
      inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]`
        or None for self-attention, in which case key/values will be derived
        from inputs_q.
      num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
        should be divisible by the number of heads.
      dtype: the dtype of the computation (default: float32)
      qkv_features: dimension of the key, query, and value.
      out_features: dimension of the last projection
      attention_axis: axes over which the attention is applied (`None` means
        attention over all axes except batch, heads, and features).
      causal_mask: boolean specifying whether to apply a causal mask on the
        attention weights. If True, the output at timestep `t` will not depend
        on inputs at timesteps strictly greater than `t`.
      padding_mask: boolean specifying query tokens that are pad tokens.
      key_padding_mask: boolean specifying key-value tokens that are pad tokens.
      segmentation: segment indices for packed inputs_q data.
      key_segmentation: segment indices for packed inputs_kv data.
      cache: an instance of `flax.nn.attention.Cache` used for efficient
        autoregressive decoding.
      broadcast_dropout: bool: use a broadcasted dropout along batch dims.
      dropout_rng: JAX PRNGKey: to be used for dropout
      dropout_rate: dropout rate
      deterministic: bool, whether to run deterministically (i.e. without
        applying dropout).
      precision: numerical precision of the computation; see `jax.lax.Precision`
        for details.
      kernel_init: initializer for the kernel of the Dense layers.
      bias_init: initializer for the bias of the Dense layers.
      bias: bool: whether pointwise QKVO dense transforms use bias.
      block_size: int, size of each attention block.
      max_num_blocks: int, maximum number of blocks.
      sort_activation: str, one of {'softmax', 'sinkhorn', 'gumbel_sinkhorn'}.

    Returns:
      output of shape `[bs, dim1, dim2, ..., dimN, features]`.
    """

    assert causal_mask or not cache, (
        'Caching is only supported for causal attention.')

    assert inputs_q.ndim == 3

    if inputs_kv is None:
      inputs_kv = inputs_q

    if attention_axis is None:
      attention_axis = tuple(range(1, inputs_q.ndim - 1))

    features = out_features or inputs_q.shape[-1]
    qkv_features = qkv_features or inputs_q.shape[-1]

    assert qkv_features % num_heads == 0, (
        'Memory dimension must be divisible by number of heads.')
    head_dim = qkv_features // num_heads

    dense = nn.DenseGeneral.partial(
        axis=-1,
        features=(num_heads, head_dim),
        kernel_init=kernel_init,
        bias_init=bias_init,
        bias=bias,
        precision=precision)
    # project inputs_q to multi-headed q/k/v
    # dimensions are then [bs, dims..., n_heads, n_features_per_head]
    qlength = inputs_q.shape[-2]
    bs = inputs_q.shape[0]
    kvlength = inputs_kv.shape[-2]

    query, key, value = (dense(inputs_q, dtype=dtype, name='query'),
                         dense(inputs_kv, dtype=dtype, name='key'),
                         dense(inputs_kv, dtype=dtype, name='value'))

    if cache:
      assert isinstance(cache, Cache), 'cache must be an instance of Cache'
      if self.is_initializing():
        cache.store(onp.array((key.ndim,) + key.shape[-2:], dtype=onp.int32))
      else:
        cache_entry = cache.retrieve(None)
        expected_shape = list(cache_entry.key.shape[:-2])
        for attn_dim in attention_axis:
          expected_shape[attn_dim] = 1
        expected_shape = tuple(expected_shape) + inputs_q.shape[-1:]
        if expected_shape != inputs_q.shape:
          raise ValueError('Invalid shape provided, '
                           'expected shape %s instead got %s.' %
                           (expected_shape, inputs_q.shape))

        if not isinstance(cache_entry, _CacheEntry):
          raise ValueError('Cache is not initialized.')

        cshape = cache_entry.key.shape
        indices = [0] * len(cshape)
        i = cache_entry.i
        attn_size = onp.prod(onp.take(cshape, attention_axis))
        for attn_dim in attention_axis:
          attn_size //= cshape[attn_dim]
          indices[attn_dim] = i // attn_size
          i = i % attn_size

        key = lax.dynamic_update_slice(cache_entry.key, key, indices)
        value = lax.dynamic_update_slice(cache_entry.value, value, indices)
        one = jnp.array(1, jnp.uint32)
        cache_entry = cache_entry.replace(i=cache_entry.i + one,
                                          key=key,
                                          value=value)
        cache.store(cache_entry)

        key_padding_mask = jnp.broadcast_to(
            (jnp.arange(cshape[1]) < cache_entry.i), cshape[:2])
        key_padding_mask = key_padding_mask.astype(jnp.float32)[..., None]

    # block reshape before attention
    num_query_blocks = qlength // block_size
    num_kv_blocks = kvlength // block_size

    block_query = jnp.reshape(
        query, (bs, block_size, num_query_blocks, num_heads, head_dim))
    block_key = jnp.reshape(
        key, (bs, block_size, num_kv_blocks, num_heads, head_dim))
    block_value = jnp.reshape(
        value, (bs, block_size, num_kv_blocks, num_heads, head_dim))

    if causal_mask:
      # causal masking must not mix information across block boundaries.
      sum_key = jnp.cumsum(block_key, axis=1)
      sum_key = sum_key[:, 0, :, :, :]  # take first item
    else:
      sum_key = jnp.sum(block_key, axis=1)

    # sort net on head_dim dimensions
    sort_out = nn.DenseGeneral(sum_key, axis=-1,
                               features=(max_num_blocks),
                               kernel_init=kernel_init,
                               bias_init=bias_init,
                               bias=bias,
                               precision=precision)

    # slice to shape (bs, num_kv_blocks, num_heads, num_query_blocks)
    sort_out = sort_out[:, :, :, :num_query_blocks]

    # simple softmax sorting first.

    if sort_activation == 'sinkhorn':
      permutation = sinkhorn_operator(
          jnp.reshape(sort_out, (-1, num_kv_blocks, num_query_blocks)),
          causal=causal_mask)
      permutation = jnp.reshape(permutation, (-1, num_kv_blocks, num_heads,
                                              num_query_blocks))
    else:
      if causal_mask:
        block_mask = _make_causal_mask(key, attention_axis)
        sort_out += block_mask
      permutation = jax.nn.softmax(sort_out, axis=-1)

    sorted_key = jnp.einsum('bskhd,bnhl->bsnhd', block_key, permutation)
    sorted_value = jnp.einsum('bskhd,bnhl->bsnhd', block_value, permutation)

    # create attention masks
    mask_components = []
    sorted_mask_components = []

    if causal_mask:
      # TODO(yitay): Test this causal masking.
      if cache and not self.is_initializing():
        bias_pre_shape = (1,) * (key.ndim - 1)
        attn_shape = tuple(onp.take(key.shape, attention_axis))
        attn_size = onp.prod(attn_shape)
        ii = jnp.arange(attn_size, dtype=jnp.uint32)
        mask = ii < cache_entry.i
        mask_components.append(mask.reshape(bias_pre_shape + attn_shape))
      else:
        mask_components.append(_make_causal_mask(key, attention_axis))

    if padding_mask is not None:
      # divide padding mask into block
      padding_mask = jnp.reshape(padding_mask,
                                 (bs * num_query_blocks, block_size, 1))
      if key_padding_mask is None:
        key_padding_mask = padding_mask

      padding_mask = make_padding_mask(
          padding_mask_query=padding_mask,
          padding_mask_key=key_padding_mask,
          query_shape=(bs * num_query_blocks, block_size, num_heads, head_dim),
          key_shape=(bs * num_kv_blocks, block_size, num_heads, head_dim),
          attention_axis=attention_axis)

      padding_mask = jnp.reshape(padding_mask,
                                 (bs, num_query_blocks, block_size, block_size))
      mask_components.append(padding_mask)
      sorted_padding_mask = jnp.einsum('bksj,bnhl->bnsj', padding_mask,
                                       permutation)
      sorted_mask_components.append(sorted_padding_mask)

    if segmentation is not None:
      if key_segmentation is None:
        key_segmentation = segmentation
      segmentation_mask = make_padding_mask(
          padding_mask_query=segmentation,
          padding_mask_key=key_segmentation,
          query_shape=(bs * num_query_blocks, block_size, num_heads, head_dim),
          key_shape=(bs * num_kv_blocks, block_size, num_heads, head_dim),
          attention_axis=attention_axis,
          segmentation_mask=True)
      segmentation_mask = jnp.reshape(segmentation_mask,
                                      (bs, num_query_blocks, block_size,
                                       block_size))
      mask_components.append(segmentation_mask)
      sorted_segmentation_mask = jnp.einsum('bksj,bnhl->bnsj',
                                            segmentation_mask,
                                            permutation)
      sorted_mask_components.append(sorted_segmentation_mask)

    if mask_components:
      attention_mask = mask_components[0]
      for component in mask_components[1:]:
        attention_mask = jnp.logical_and(attention_mask, component)

      # attention mask in the form of attention bias
      attention_bias = lax.select(
          attention_mask > 0, jnp.full(attention_mask.shape, 0.).astype(dtype),
          jnp.full(attention_mask.shape, -1e10).astype(dtype))
    else:
      attention_bias = None

    if sorted_mask_components:
      attention_mask = sorted_mask_components[0]
      for component in sorted_mask_components[1:]:
        attention_mask = jnp.logical_and(attention_mask, component)

      # attention mask in the form of attention bias
      sorted_attention_bias = lax.select(
          attention_mask > 0, jnp.full(attention_mask.shape, 0.).astype(dtype),
          jnp.full(attention_mask.shape, -1e10).astype(dtype))
    else:
      sorted_attention_bias = None

    # apply attention
    x = local_dot_product_attention(
        block_query,
        block_key,
        block_value,
        dtype=dtype,
        axis=attention_axis,
        bias=attention_bias,
        precision=precision,
        dropout_rng=dropout_rng,
        dropout_rate=dropout_rate,
        broadcast_dropout=broadcast_dropout,
        deterministic=deterministic)

    sorted_x = local_dot_product_attention(
        block_query,
        sorted_key,
        sorted_value,
        dtype=dtype,
        axis=attention_axis,
        bias=sorted_attention_bias,
        precision=precision,
        dropout_rng=dropout_rng,
        dropout_rate=dropout_rate,
        broadcast_dropout=broadcast_dropout,
        deterministic=deterministic)

    x = x + sorted_x

    x = jnp.reshape(x, (bs, qlength, num_heads, head_dim))

    # back to the original inputs dimensions
    out = nn.DenseGeneral(
        x,
        features=features,
        axis=(-2, -1),
        kernel_init=kernel_init,
        bias_init=bias_init,
        bias=bias,
        dtype=dtype,
        precision=precision,
        name='out')

    return out
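A self-contained sketch of the block partitioning above with concrete shapes, using plain jnp and a zero-initialized stand-in for the DenseGeneral sorting network (the sinkhorn branch and the block einsums are not reproduced here):

import jax
import jax.numpy as jnp

# Illustrative sizes; the defaults above are block_size=50, max_num_blocks=25.
bs, length, num_heads, head_dim = 2, 200, 4, 16
block_size, max_num_blocks = 50, 25
num_blocks = length // block_size                  # 4

key = jnp.ones((bs, length, num_heads, head_dim))
# Note the (block_size, num_blocks) axis order used by the reshape above.
block_key = key.reshape((bs, block_size, num_blocks, num_heads, head_dim))

# Per-block summary that drives the sorting network (non-causal branch).
sum_key = block_key.sum(axis=1)                    # (bs, num_blocks, heads, dim)

# Stand-in for the sort net: scores over max_num_blocks, truncated to the
# actual number of blocks and normalized with a softmax ('softmax' mode).
scores = jnp.zeros((bs, num_blocks, num_heads, max_num_blocks))
permutation = jax.nn.softmax(scores[..., :num_blocks], axis=-1)

print(block_key.shape, permutation.shape)          # (2, 50, 4, 4, 16) (2, 4, 4, 4)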
Example #3
    def apply(self,
              inputs_q,
              inputs_kv,
              num_heads,
              dtype=jnp.float32,
              qkv_features=None,
              out_features=None,
              attention_axis=None,
              causal_mask=False,
              padding_mask=None,
              key_padding_mask=None,
              segmentation=None,
              key_segmentation=None,
              cache=None,
              broadcast_dropout=True,
              dropout_rng=None,
              dropout_rate=0.,
              deterministic=False,
              precision=None,
              kernel_init=nn.linear.default_kernel_init,
              bias_init=nn.initializers.zeros,
              bias=True,
              max_length=512,
              ignore_dot_product=True,
              synthesizer_mode='factorized_random',
              k=32):
        """Applies multi-head synthesizer attention on the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention, and projects the results to an output vector.

    This can be used for encoder-decoder attention by specifying both `inputs_q`
    and `inputs_kv`, or for self-attention by only specifying `inputs_q` and
    setting `inputs_kv` to None.

    Args:
      inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
      inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]`
        or None for self-attention, in which case key/values will be derived
        from inputs_q.
      num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
        should be divisible by the number of heads.
      dtype: the dtype of the computation (default: float32)
      qkv_features: dimension of the key, query, and value.
      out_features: dimension of the last projection
      attention_axis: axes over which the attention is applied (`None` means
        attention over all axes except batch, heads, and features).
      causal_mask: boolean specifying whether to apply a causal mask on the
        attention weights. If True, the output at timestep `t` will not depend
        on inputs at timesteps strictly greater than `t`.
      padding_mask: boolean specifying query tokens that are pad tokens.
      key_padding_mask: boolean specifying key-value tokens that are pad tokens.
      segmentation: segment indices for packed inputs_q data.
      key_segmentation: segment indices for packed inputs_kv data.
      cache: an instance of `flax.nn.attention.Cache` used for efficient
        autoregressive decoding.
      broadcast_dropout: bool: use a broadcasted dropout along batch dims.
      dropout_rng: JAX PRNGKey: to be used for dropout
      dropout_rate: dropout rate
      deterministic: bool, whether to run deterministically (i.e. without
        applying dropout).
      precision: numerical precision of the computation; see `jax.lax.Precision`
        for details.
      kernel_init: initializer for the kernel of the Dense layers.
      bias_init: initializer for the bias of the Dense layers.
      bias: bool: whether pointwise QKVO dense transforms use bias.
      max_length: int, the maximum supported sequence length.
      ignore_dot_product: bool, whether to ignore the dot-product attention term.
      synthesizer_mode: str, supports 'dense', 'random' (including
        'factorized_random'), or combinations such as 'dense+random'.
      k: int, rank of the factorized random attention weights.

    Returns:
      output of shape `[bs, dim1, dim2, ..., dimN, features]`.
    """

        assert causal_mask or not cache, (
            'Caching is only supported for causal attention.')

        assert inputs_q.ndim == 3

        if inputs_kv is None:
            inputs_kv = inputs_q

        if attention_axis is None:
            attention_axis = tuple(range(1, inputs_q.ndim - 1))

        features = out_features or inputs_q.shape[-1]
        qkv_features = qkv_features or inputs_q.shape[-1]

        assert qkv_features % num_heads == 0, (
            'Memory dimension must be divisible by number of heads.')
        head_dim = qkv_features // num_heads

        dense = nn.DenseGeneral.partial(axis=-1,
                                        features=(num_heads, head_dim),
                                        kernel_init=kernel_init,
                                        bias_init=bias_init,
                                        bias=bias,
                                        precision=precision)
        # project inputs_q to multi-headed q/k/v
        # dimensions are then [bs, dims..., n_heads, n_features_per_head]
        qlength = inputs_q.shape[-2]
        kvlength = inputs_kv.shape[-2]

        if ignore_dot_product:
            value = dense(inputs_kv, dtype=dtype, name='value')
            key = value
            query = inputs_q
        else:
            query, key, value = (dense(inputs_q, dtype=dtype, name='query'),
                                 dense(inputs_kv, dtype=dtype, name='key'),
                                 dense(inputs_kv, dtype=dtype, name='value'))

        syn_weights_list = []
        logging.info(synthesizer_mode)
        if 'random' in synthesizer_mode:
            if 'factorized_random' in synthesizer_mode:
                logging.info('Using factorized random')
                rand_syn_weights1 = self.param('random1',
                                               (num_heads, max_length, k),
                                               kernel_init)
                rand_syn_weights2 = self.param('random2',
                                               (num_heads, k, max_length),
                                               kernel_init)
                rand_syn_weights1 = rand_syn_weights1[:, :qlength, :]
                rand_syn_weights2 = rand_syn_weights2[:, :, :kvlength]
                rand_syn_weights = jnp.einsum('hlk,hkn->hln',
                                              rand_syn_weights1,
                                              rand_syn_weights2)
                rand_syn_weights = jax.lax.broadcast(rand_syn_weights,
                                                     (inputs_q.shape[0], ))
                syn_weights_list.append(rand_syn_weights)
            else:
                rand_syn_weights = self.param(
                    'random', (num_heads, max_length, max_length), kernel_init)
                rand_syn_weights = rand_syn_weights[:, :qlength, :kvlength]
                rand_syn_weights = jax.lax.broadcast(rand_syn_weights,
                                                     (inputs_q.shape[0], ))
                syn_weights_list.append(rand_syn_weights)
        if 'dense' in synthesizer_mode:
            dense_syn = nn.DenseGeneral.partial(axis=-1,
                                                features=(num_heads, head_dim),
                                                kernel_init=kernel_init,
                                                bias_init=bias_init,
                                                bias=bias,
                                                precision=precision,
                                                name='dense_syn',
                                                dtype=dtype)
            # TODO(yitay): Change this to nn.Dense and make sure it works
            dense_syn_length = nn.linear.DenseGeneral.partial(
                axis=-1,
                features=(max_length),
                kernel_init=kernel_init,
                bias_init=bias_init,
                bias=bias,
                precision=precision,
                name='dense_syn2',
                dtype=dtype)
            proj = dense_syn(inputs_q, dtype=dtype, name='dense_syn')
            proj = jax.nn.relu(proj)
            proj = dense_syn_length(proj, dtype=dtype, name='dense_syn_len')
            # TODO(yitay) check if this reshape is needed
            dense_syn_weights = proj.reshape(
                (inputs_q.shape[0], num_heads, qlength, max_length))
            dense_syn_weights = dense_syn_weights[:, :, :, :qlength]
            syn_weights_list.append(dense_syn_weights)
        if cache:
            assert isinstance(cache,
                              Cache), 'cache must be an instance of Cache'
            if self.is_initializing():
                cache.store(
                    onp.array((key.ndim, ) + key.shape[-2:], dtype=onp.int32))
            else:
                cache_entry = cache.retrieve(None)
                expected_shape = list(cache_entry.key.shape[:-2])
                for attn_dim in attention_axis:
                    expected_shape[attn_dim] = 1
                expected_shape = tuple(expected_shape) + inputs_q.shape[-1:]
                if expected_shape != inputs_q.shape:
                    raise ValueError('Invalid shape provided, '
                                     'expected shape %s instead got %s.' %
                                     (expected_shape, inputs_q.shape))

                if not isinstance(cache_entry, _CacheEntry):
                    raise ValueError('Cache is not initialized.')

                cshape = cache_entry.key.shape
                indices = [0] * len(cshape)
                i = cache_entry.i
                attn_size = onp.prod(onp.take(cshape, attention_axis))
                for attn_dim in attention_axis:
                    attn_size //= cshape[attn_dim]
                    indices[attn_dim] = i // attn_size
                    i = i % attn_size

                key = lax.dynamic_update_slice(cache_entry.key, key, indices)
                value = lax.dynamic_update_slice(cache_entry.value, value,
                                                 indices)
                one = jnp.array(1, jnp.uint32)
                cache_entry = cache_entry.replace(i=cache_entry.i + one,
                                                  key=key,
                                                  value=value)
                cache.store(cache_entry)

                key_padding_mask = jnp.broadcast_to(
                    (jnp.arange(cshape[1]) < cache_entry.i), cshape[:2])
                key_padding_mask = key_padding_mask.astype(jnp.float32)[...,
                                                                        None]

        # create attention masks
        mask_components = []

        if causal_mask:
            if cache and not self.is_initializing():
                bias_pre_shape = (1, ) * (key.ndim - 1)
                attn_shape = tuple(onp.take(key.shape, attention_axis))
                attn_size = onp.prod(attn_shape)
                ii = jnp.arange(attn_size, dtype=jnp.uint32)
                mask = ii < cache_entry.i
                mask_components.append(
                    mask.reshape(bias_pre_shape + attn_shape))
            else:
                mask_components.append(_make_causal_mask(key, attention_axis))

        if not ignore_dot_product:
            if padding_mask is not None:
                if key_padding_mask is None:
                    key_padding_mask = padding_mask
                padding_mask = make_padding_mask(
                    padding_mask_query=padding_mask,
                    padding_mask_key=key_padding_mask,
                    query_shape=query.shape,
                    key_shape=key.shape,
                    attention_axis=attention_axis)
                mask_components.append(padding_mask)

            if segmentation is not None:
                if key_segmentation is None:
                    key_segmentation = segmentation
                segmentation_mask = make_padding_mask(
                    padding_mask_query=segmentation,
                    padding_mask_key=key_segmentation,
                    query_shape=query.shape,
                    key_shape=key.shape,
                    attention_axis=attention_axis,
                    segmentation_mask=True)
                mask_components.append(segmentation_mask)

        if mask_components:
            attention_mask = mask_components[0]
            for component in mask_components[1:]:
                attention_mask = jnp.logical_and(attention_mask, component)

            # attention mask in the form of attention bias
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.).astype(dtype),
                jnp.full(attention_mask.shape, -1e10).astype(dtype))
        else:
            attention_bias = None

        # apply attention
        x = synthetic_attention(query,
                                key,
                                value,
                                syn_weights_list,
                                dtype=dtype,
                                axis=attention_axis,
                                bias=attention_bias,
                                precision=precision,
                                dropout_rng=dropout_rng,
                                dropout_rate=dropout_rate,
                                broadcast_dropout=broadcast_dropout,
                                deterministic=deterministic,
                                ignore_dot_product=ignore_dot_product)

        # back to the original inputs dimensions
        out = nn.DenseGeneral(x,
                              features=features,
                              axis=(-2, -1),
                              kernel_init=kernel_init,
                              bias_init=bias_init,
                              bias=bias,
                              dtype=dtype,
                              precision=precision,
                              name='out')

        return out
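A minimal sketch of the 'factorized_random' branch above, with the learned parameters replaced by random arrays purely to show the shape flow (the real weights come from self.param and are trained):

import jax
import jax.numpy as jnp

num_heads, max_length, k = 4, 512, 32
qlength = kvlength = 128
batch = 2

rng1, rng2 = jax.random.split(jax.random.PRNGKey(0))
r1 = jax.random.normal(rng1, (num_heads, max_length, k))   # 'random1'
r2 = jax.random.normal(rng2, (num_heads, k, max_length))   # 'random2'

# Slice to the actual lengths, contract the rank-k axis, then broadcast
# over the batch dimension, exactly as in the snippet above.
r1 = r1[:, :qlength, :]
r2 = r2[:, :, :kvlength]
syn_weights = jnp.einsum('hlk,hkn->hln', r1, r2)           # (heads, qlen, kvlen)
syn_weights = jax.lax.broadcast(syn_weights, (batch,))     # (batch, heads, qlen, kvlen)
print(syn_weights.shape)                                   # (2, 4, 128, 128)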
Example #4
    def apply(self,
              inputs_q,
              inputs_kv,
              num_heads,
              dtype=jnp.float32,
              qkv_features=None,
              out_features=None,
              causal_mask=False,
              padding_mask=None,
              key_padding_mask=None,
              segmentation=None,
              key_segmentation=None,
              cache=None,
              broadcast_dropout=True,
              dropout_rng=None,
              dropout_rate=0.,
              deterministic=False,
              precision=None,
              kernel_init=nn.linear.default_kernel_init,
              bias_init=nn.initializers.zeros,
              bias=True):
        """Applies linear attention on the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies linear attention, and projects the results to an output vector.

    Args:
      inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
      inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]` or
        None for self-attention, in which case key/values will be derived from
        inputs_q.
      num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
        should be divisible by the number of heads.
      dtype: the dtype of the computation (default: float32)
      qkv_features: dimension of the key, query, and value.
      out_features: dimension of the last projection
      causal_mask: boolean specifying whether to apply a causal mask on the
        attention weights. If True, the output at timestep `t` will not depend
        on inputs at timesteps strictly greater than `t`.
      padding_mask: boolean specifying query tokens that are pad tokens.
      key_padding_mask: boolean specifying key-value tokens that are pad tokens.
      segmentation: segment indices for packed inputs_q data.
      key_segmentation: segment indices for packed inputs_kv data.
      cache: an instance of `flax.nn.attention.Cache` used for efficient
        autoregressive decoding.
      broadcast_dropout: bool: use a broadcasted dropout along batch dims.
      dropout_rng: JAX PRNGKey: to be used for dropout
      dropout_rate: dropout rate
      deterministic: bool, whether to run deterministically (i.e. without
        applying dropout).
      precision: numerical precision of the computation; see `jax.lax.Precision`
        for details.
      kernel_init: initializer for the kernel of the Dense layers.
      bias_init: initializer for the bias of the Dense layers.
      bias: bool: whether pointwise QKVO dense transforms use bias.

    Returns:
      output of shape `[bs, dim1, dim2, ..., dimN, features]`.
    """

        if padding_mask is not None:
            raise NotImplementedError(
                'Currently, we do not support autoregressive decoding.')

        assert causal_mask or not cache, (
            'Caching is only supported for causal attention.')

        assert inputs_q.ndim == 3

        if inputs_kv is None:
            inputs_kv = inputs_q

        features = out_features or inputs_q.shape[-1]
        qkv_features = qkv_features or inputs_q.shape[-1]

        assert qkv_features % num_heads == 0, (
            'Memory dimension must be divisible by number of heads.')
        head_dim = qkv_features // num_heads

        dense = nn.DenseGeneral.partial(axis=-1,
                                        features=(num_heads, head_dim),
                                        kernel_init=kernel_init,
                                        bias_init=bias_init,
                                        bias=bias,
                                        precision=precision)

        # project inputs_q to multi-headed q/k/v
        # dimensions are then [bs, dims..., n_heads, n_features_per_head]
        query, key, value = (dense(inputs_q, dtype=dtype, name='query'),
                             dense(inputs_kv, dtype=dtype, name='key'),
                             dense(inputs_kv, dtype=dtype, name='value'))

        if cache:
            raise NotImplementedError(
                'Decoding not supported in LinearAttention.')

        # apply regular dot product attention
        x = linear_attention(query,
                             key,
                             value,
                             dropout_rng=dropout_rng,
                             dropout_rate=dropout_rate,
                             broadcast_dropout=broadcast_dropout,
                             deterministic=deterministic)

        # back to the original inputs dimensions
        out = nn.DenseGeneral(x,
                              features=features,
                              axis=(-2, -1),
                              kernel_init=kernel_init,
                              bias_init=bias_init,
                              bias=bias,
                              dtype=dtype,
                              precision=precision,
                              name='out')

        return out
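The `linear_attention` helper called above is not shown in the snippet. Purely as an illustration, the following is a self-contained sketch of one common linearized-attention formulation (softmax feature maps applied to queries and keys separately, associating the key-value contraction first), under the assumption that inputs are shaped [batch, length, heads, head_dim]:

import jax
import jax.numpy as jnp

def linear_attention_sketch(query, key, value):
  # Normalize queries over features and keys over the length axis, then
  # contract keys with values first so the cost is linear in length.
  q = jax.nn.softmax(query, axis=-1)
  k = jax.nn.softmax(key, axis=1)
  context = jnp.einsum('blhd,blhm->bhdm', k, value)   # (batch, heads, d, m)
  return jnp.einsum('blhd,bhdm->blhm', q, context)    # (batch, length, heads, m)

x = jnp.ones((2, 16, 4, 8))
print(linear_attention_sketch(x, x, x).shape)         # (2, 16, 4, 8)

This sketch ignores dropout; the snippet above forwards the dropout arguments to the real helper.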
    def apply(self,
              inputs_q,
              inputs_kv,
              num_heads,
              sliding_window_size=512,
              global_mask=None,
              causal_mask=False,
              dtype=jnp.float32,
              qkv_features=None,
              out_features=None,
              padding_mask=None,
              key_padding_mask=None,
              segmentation=None,
              key_segmentation=None,
              broadcast_dropout=True,
              dropout_rng=None,
              dropout_rate=0.,
              deterministic=False,
              precision=None,
              kernel_init=nn.linear.default_kernel_init,
              bias_init=nn.initializers.zeros,
              bias=True):
        """Applies longformer multi-head dot product attention on the input data.

    Args:
      inputs_q: input queries of shape `[bs, seq_len, features]`.
      inputs_kv: key/values of shape `[bs, seq_len, features]` or `None` for
        self-attention, in which case key/values will be derived from inputs_q.
      num_heads: number of attention heads (should divide number of features).
      sliding_window_size: size of sliding window attention to use.
      global_mask: boolean matrix of shape `[bs, seq_len]`, where `True`
        indicates that the position is globally attended. By default, no global
        attention is used.
      causal_mask: If true, apply causal attention masking.
      dtype: the dtype of the computation (default: float32).
      qkv_features: dimension of the key, query, and value.
      out_features: dimension of the last projection.
      padding_mask: boolean specifying query tokens that are pad tokens.
      key_padding_mask: boolean specifying key-value tokens that are pad tokens.
      segmentation: segment indices for packed inputs_q data.
      key_segmentation: segment indices for packed inputs_kv data.
      broadcast_dropout: use a broadcasted dropout along batch dims.
      dropout_rng: JAX PRNGKey to be used for dropout.
      dropout_rate: dropout rate.
      deterministic: if true, don't apply dropout (run deterministically).
      precision: numerical precision of the computation.
      kernel_init: initializer for the kernel of the Dense layers.
      bias_init: initializer for the bias of the Dense layers.
      bias: whether pointwise QKVO dense transforms use bias.

    Returns:
      output of shape `[bs, seq_len, features]`.
    """
        if inputs_kv is None:
            inputs_kv = inputs_q

        batch_size = inputs_q.shape[0]
        features = out_features or inputs_q.shape[-1]
        qkv_features = qkv_features or inputs_q.shape[-1]
        seq_len = inputs_q.shape[1]

        assert qkv_features % num_heads == 0, (
            'Memory dimension must be divisible by number of heads.')
        head_dim = qkv_features // num_heads

        dense = nn.DenseGeneral.partial(axis=-1,
                                        features=(num_heads, head_dim),
                                        kernel_init=kernel_init,
                                        bias_init=bias_init,
                                        bias=bias,
                                        precision=precision)

        query_sw = dense(inputs_q, dtype=dtype, name='query_sliding_window')
        key_sw = dense(inputs_kv, dtype=dtype, name='key_sliding_window')
        value_sw = dense(inputs_kv, dtype=dtype, name='value_sliding_window')

        query_global = dense(inputs_q, dtype=dtype, name='query_global')
        key_global = dense(inputs_kv, dtype=dtype, name='key_global')
        value_global = dense(inputs_kv, dtype=dtype, name='value_global')

        if global_mask is None:
            global_mask = jnp.full((batch_size, seq_len), False)

        full_global_mask = _build_global_mask(global_mask)

        sliding_window_mask = _build_sliding_window_mask(
            window_size=sliding_window_size, global_mask=global_mask)

        x_sw = _get_attention_result(query=query_sw,
                                     key=key_sw,
                                     value=value_sw,
                                     dtype=dtype,
                                     precision=precision,
                                     dropout_rng=dropout_rng,
                                     dropout_rate=dropout_rate,
                                     broadcast_dropout=broadcast_dropout,
                                     deterministic=deterministic,
                                     mask=sliding_window_mask,
                                     padding_mask=padding_mask,
                                     key_padding_mask=key_padding_mask,
                                     segmentation=segmentation,
                                     key_segmentation=key_segmentation,
                                     apply_causal_mask=causal_mask)

        x_global = _get_attention_result(query=query_global,
                                         key=key_global,
                                         value=value_global,
                                         dtype=dtype,
                                         precision=precision,
                                         dropout_rng=dropout_rng,
                                         dropout_rate=dropout_rate,
                                         broadcast_dropout=broadcast_dropout,
                                         deterministic=deterministic,
                                         mask=full_global_mask,
                                         padding_mask=padding_mask,
                                         key_padding_mask=key_padding_mask,
                                         segmentation=segmentation,
                                         key_segmentation=key_segmentation,
                                         apply_causal_mask=causal_mask)

        x = jnp.where(global_mask[:, :, jnp.newaxis, jnp.newaxis], x_global,
                      x_sw)

        # back to the original inputs dimensions
        out = nn.DenseGeneral(x,
                              features=features,
                              axis=(-2, -1),
                              kernel_init=kernel_init,
                              bias_init=bias_init,
                              bias=bias,
                              dtype=dtype,
                              precision=precision,
                              name='out')

        return out
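The mask helpers `_build_global_mask` and `_build_sliding_window_mask` are not shown in the snippet. Purely as an illustration of the sliding-window idea, a minimal standalone mask builder might look like this (the real helper also has to fold in `global_mask` and the batch/head dimensions):

import jax.numpy as jnp

def sliding_window_mask_sketch(seq_len, window_size):
  # True where query position i may attend to key position j,
  # i.e. |i - j| <= window_size // 2.
  idx = jnp.arange(seq_len)
  dist = jnp.abs(idx[:, None] - idx[None, :])
  return dist <= window_size // 2

print(sliding_window_mask_sketch(6, 4).astype(jnp.int32))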
Example #6
    def apply(self,
              inputs_q,
              inputs_kv,
              num_heads,
              dtype=jnp.float32,
              qkv_features=None,
              out_features=None,
              attention_axis=None,
              causal_mask=False,
              padding_mask=None,
              key_padding_mask=None,
              segmentation=None,
              key_segmentation=None,
              cache=None,
              broadcast_dropout=True,
              dropout_rng=None,
              dropout_rate=0.,
              deterministic=False,
              precision=None,
              kernel_init=nn.linear.default_kernel_init,
              bias_init=nn.initializers.zeros,
              bias=True,
              low_rank_features=16,
              max_len=1000):
        """Applies Linformer's low-rank attention on the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention, and projects the results to an output vector.

    This can be used for encoder-decoder attention by specifying both `inputs_q`
    and `inputs_kv`, or for self-attention by only specifying `inputs_q` and
    setting `inputs_kv` to None.

    Args:
      inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
      inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]`
        or None for self-attention, in which case key/values will be derived
        from inputs_q.
      num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
        should be divisible by the number of heads.
      dtype: the dtype of the computation (default: float32)
      qkv_features: dimension of the key, query, and value.
      out_features: dimension of the last projection
      attention_axis: axes over which the attention is applied (`None` means
        attention over all axes except batch, heads, and features).
      causal_mask: boolean specifying whether to apply a causal mask on the
        attention weights. If True, the output at timestep `t` will not depend
        on inputs at timesteps strictly greater than `t`.
      padding_mask: boolean specifying query tokens that are pad tokens.
      key_padding_mask: boolean specifying key-value tokens that are pad tokens.
      segmentation: segment indices for packed inputs_q data.
      key_segmentation: segment indices for packed inputs_kv data.
      cache: an instance of `flax.nn.attention.Cache` used for efficient
        autoregressive decoding.
      broadcast_dropout: bool: use a broadcasted dropout along batch dims.
      dropout_rng: JAX PRNGKey: to be used for dropout
      dropout_rate: dropout rate
      deterministic: bool, whether to run deterministically (i.e. without
        applying dropout).
      precision: numerical precision of the computation; see `jax.lax.Precision`
        for details.
      kernel_init: initializer for the kernel of the Dense layers.
      bias_init: initializer for the bias of the Dense layers.
      bias: bool: whether pointwise QKVO dense transforms use bias.
      low_rank_features: int, number of features the length dimension is
        projected down to.
      max_len: int, maximum supported sequence length.

    Returns:
      output of shape `[bs, dim1, dim2, ..., dimN, features]`.
    """

        assert causal_mask or not cache, (
            'Caching is only supported for causal attention.')

        assert inputs_q.ndim == 3

        if inputs_kv is None:
            inputs_kv = inputs_q

        if attention_axis is None:
            attention_axis = tuple(range(1, inputs_q.ndim - 1))

        features = out_features or inputs_q.shape[-1]
        qkv_features = qkv_features or inputs_q.shape[-1]

        assert qkv_features % num_heads == 0, (
            'Memory dimension must be divisible by number of heads.')
        head_dim = qkv_features // num_heads

        dense = nn.DenseGeneral.partial(axis=-1,
                                        features=(num_heads, head_dim),
                                        kernel_init=kernel_init,
                                        bias_init=bias_init,
                                        bias=bias,
                                        precision=precision)
        # project inputs_q to multi-headed q/k/v
        # dimensions are then [bs, dims..., n_heads, n_features_per_head]

        query, key, value = (dense(inputs_q, dtype=dtype, name='query'),
                             dense(inputs_kv, dtype=dtype, name='key'),
                             dense(inputs_kv, dtype=dtype, name='value'))

        def low_rank_projection(inputs, kernel, precision):
            """low rank projection."""
            input_dim = inputs.shape[1]
            # this kernel/parameter relies on sequence length
            kernel = kernel[:input_dim, :]
            inputs = inputs.transpose((0, 3, 2, 1))
            y = lax.dot_general(inputs,
                                kernel,
                                (((inputs.ndim - 1, ), (0, )), ((), ())),
                                precision=precision)
            y = y.transpose((0, 3, 2, 1))
            return y

        # Shared Kernel for low-rank length dimension projections.
        low_rank_kernel = self.param('lr_kernel', (max_len, low_rank_features),
                                     kernel_init)

        key = low_rank_projection(key, low_rank_kernel, precision)
        value = low_rank_projection(value, low_rank_kernel, precision)

        if cache:
            raise NotImplementedError('Decoding not supported in Linformer.')

        # TODO(yitay) Does Linformer care about masks?
        # Since everything is mixed in length dimension are masks relevant?

        # apply regular dot product attention
        x = dot_product_attention(query,
                                  key,
                                  value,
                                  dtype=dtype,
                                  axis=attention_axis,
                                  bias=None,
                                  precision=precision,
                                  dropout_rng=dropout_rng,
                                  dropout_rate=dropout_rate,
                                  broadcast_dropout=broadcast_dropout,
                                  deterministic=deterministic)

        # back to the original inputs dimensions
        out = nn.DenseGeneral(x,
                              features=features,
                              axis=(-2, -1),
                              kernel_init=kernel_init,
                              bias_init=bias_init,
                              bias=bias,
                              dtype=dtype,
                              precision=precision,
                              name='out')

        return out
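A standalone version of the `low_rank_projection` above with concrete shapes, showing how the key/value length axis is projected from the sequence length down to `low_rank_features` with a shared kernel (the learned `lr_kernel` parameter is replaced by a constant array here just to make the sketch runnable):

import jax.numpy as jnp
from jax import lax

bs, seq_len, num_heads, head_dim = 2, 128, 4, 16
max_len, low_rank_features = 1000, 16

key = jnp.ones((bs, seq_len, num_heads, head_dim))
kernel = jnp.ones((max_len, low_rank_features)) / max_len
kernel = kernel[:seq_len, :]              # slice to the actual sequence length

x = key.transpose((0, 3, 2, 1))           # [bs, head_dim, heads, seq_len]
y = lax.dot_general(x, kernel,
                    (((x.ndim - 1,), (0,)), ((), ())))   # contract seq_len
y = y.transpose((0, 3, 2, 1))             # [bs, low_rank_features, heads, head_dim]
print(y.shape)                            # (2, 16, 4, 16)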
    def apply(self,
              inputs_q,
              inputs_kv,
              num_heads,
              block_size=64,
              num_rand_blocks=3,
              dtype=jnp.float32,
              qkv_features=None,
              out_features=None,
              attention_axis=None,
              causal_mask=False,
              padding_mask=None,
              key_padding_mask=None,
              segmentation=None,
              key_segmentation=None,
              cache=None,
              broadcast_dropout=True,
              dropout_rng=None,
              dropout_rate=0.,
              deterministic=False,
              precision=None,
              kernel_init=nn.linear.default_kernel_init,
              bias_init=nn.initializers.zeros,
              bias=True,
              connectivity_seed=None):
        """Applies multi-head dot product attention on the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention, and projects the results to an output vector.

    This can be used for encoder-decoder attention by specifying both `inputs_q`
    and `inputs_kv`, or for self-attention by only specifying `inputs_q` and
    setting `inputs_kv` to None.

    Args:
      inputs_q: input queries of shape `[bs, length, features]`.
      inputs_kv: key/values of shape `[bs, length, features]` or
        None for self-attention, in which case key/values will be derived from
        inputs_q.
      num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
        should be divisible by the number of heads.
      block_size: Size for local attention around diagonal of attention.
      num_rand_blocks: int. Number of random chunks per row.
      dtype: the dtype of the computation (default: float32)
      qkv_features: dimension of the key, query, and value.
      out_features: dimension of the last projection
      attention_axis: axes over which the attention is applied ( 'None' means
        attention over all axes, but batch, heads, and features).
      causal_mask: boolean specifying whether to apply a causal mask on the
        attention weights. If True, the output at timestep `t` will not depend
        on inputs at timesteps strictly greater than `t`.
      padding_mask: boolean specifying which query tokens are padding.
      key_padding_mask: boolean specifying which key/value tokens are padding.
      segmentation: segment indices for packed inputs_q data.
      key_segmentation: segment indices for packed inputs_kv data.
      cache: an instance of `flax.nn.attention.Cache` used for efficient
        autoregressive decoding.
      broadcast_dropout: bool: use a broadcasted dropout along batch dims.
      dropout_rng: JAX PRNGKey to be used for dropout.
      dropout_rate: dropout rate.
      deterministic: bool, whether to run deterministically (i.e. without
        applying dropout).
      precision: numerical precision of the computation; see
        `jax.lax.Precision` for details.
      kernel_init: initializer for the kernel of the Dense layers.
      bias_init: initializer for the bias of the Dense layers.
      bias: bool: whether pointwise QKVO dense transforms use bias.
      connectivity_seed: Seed for random block sparse attention.

    Returns:
      output of shape `[bs, length, features]`.
    """

        orig_seqlen = inputs_q.shape[-2]
        logging.info(inputs_q)
        extra_len = -orig_seqlen % block_size  # 0 when already a block multiple
        pad_width = jnp.array([[0, 0], [0, extra_len], [0, 0]])
        mask_pad = jnp.array([[0, 0], [0, extra_len], [0, 0]])
        if padding_mask is not None:
            padding_mask = jnp.pad(padding_mask, mask_pad,
                                   constant_values=-1e9)

        inputs_q = jnp.pad(inputs_q, pad_width)
        if inputs_kv is not None:
            inputs_kv = jnp.pad(inputs_kv, pad_width)
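        # Hedged worked example: with block_size=64 and orig_seqlen=100,
        # extra_len = (-100) % 64 = 28, so the inputs are padded from length
        # 100 to 128 (a multiple of the block size); the padded positions are
        # masked via `padding_mask` and dropped again by the final
        # `out[:, :orig_seqlen, :]` slice.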

        assert causal_mask or not cache, (
            'Caching is only supported for causal attention.')

        if inputs_kv is None:
            inputs_kv = inputs_q

        if attention_axis is None:
            attention_axis = tuple(range(1, inputs_q.ndim - 1))

        features = out_features or inputs_q.shape[-1]
        qkv_features = qkv_features or inputs_q.shape[-1]

        assert qkv_features % num_heads == 0, (
            'Memory dimension must be divisible by number of heads.')
        head_dim = qkv_features // num_heads
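        # For example, qkv_features=512 with num_heads=8 gives head_dim=64.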

        dense = nn.DenseGeneral.partial(axis=-1,
                                        features=(num_heads, head_dim),
                                        kernel_init=kernel_init,
                                        bias_init=bias_init,
                                        bias=bias,
                                        precision=precision)
        # project inputs_q to multi-headed q/k/v
        # dimensions are then [bs, dims..., n_heads, n_features_per_head]
        query, key, value = (dense(inputs_q, dtype=dtype, name='query'),
                             dense(inputs_kv, dtype=dtype, name='key'),
                             dense(inputs_kv, dtype=dtype, name='value'))

        if cache:
            assert isinstance(
                cache, attention.Cache), 'cache must be an instance of Cache'
            if self.is_initializing():
                cache.store(
                    onp.array((key.ndim, ) + key.shape[-2:], dtype=onp.int32))
            else:
                cache_entry = cache.retrieve(None)
                expected_shape = list(cache_entry.key.shape[:-2])
                for attn_dim in attention_axis:
                    expected_shape[attn_dim] = 1
                expected_shape = tuple(expected_shape) + inputs_q.shape[-1:]
                if expected_shape != inputs_q.shape:
                    raise ValueError('Invalid shape provided, '
                                     'expected shape %s instead got %s.' %
                                     (expected_shape, inputs_q.shape))

                if not isinstance(cache_entry, attention._CacheEntry):  # pylint: disable=protected-access
                    raise ValueError('Cache is not initialized.')

                cshape = cache_entry.key.shape
                indices = [0] * len(cshape)
                i = cache_entry.i
                attn_size = onp.prod(onp.take(cshape, attention_axis))
                for attn_dim in attention_axis:
                    attn_size //= cshape[attn_dim]
                    indices[attn_dim] = i // attn_size
                    i = i % attn_size

                key = lax.dynamic_update_slice(cache_entry.key, key, indices)
                value = lax.dynamic_update_slice(cache_entry.value, value,
                                                 indices)
                one = jnp.array(1, jnp.uint32)
                cache_entry = cache_entry.replace(i=cache_entry.i + one,
                                                  key=key,
                                                  value=value)
                cache.store(cache_entry)

                # TODO(levskaya): verify this is still needed in translation decoding.
                key_padding_mask = jnp.broadcast_to(
                    (jnp.arange(cshape[1]) < cache_entry.i), cshape[:2])
                key_padding_mask = key_padding_mask.astype(jnp.float32)[...,
                                                                        None]
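        # Hedged worked example for the cached-decoding update above: with a
        # cached key of shape [bs, 8, 8, heads, head_dim] and
        # attention_axis=(1, 2), decode step i=11 is unflattened row-major
        # into indices [0, 11 // 8, 11 % 8, 0, 0] = [0, 1, 3, 0, 0], i.e. the
        # position where the new key/value slice is written.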

        if causal_mask:  # Falls back to full attention with a causal mask.
            # create attention masks
            mask_components = []

            if cache and not self.is_initializing():
                bias_pre_shape = (1, ) * (key.ndim - 1)
                attn_shape = tuple(onp.take(key.shape, attention_axis))
                attn_size = onp.prod(attn_shape)
                ii = jnp.arange(attn_size, dtype=jnp.uint32)
                mask = ii < cache_entry.i
                mask_components.append(
                    mask.reshape(bias_pre_shape + attn_shape))
            else:
                mask_components.append(
                    attention._make_causal_mask(key, attention_axis))  # pylint: disable=protected-access

            if padding_mask is not None:
                if key_padding_mask is None:
                    key_padding_mask = padding_mask
                padding_mask = attention.make_padding_mask(
                    padding_mask_query=padding_mask,
                    padding_mask_key=key_padding_mask,
                    query_shape=query.shape,
                    key_shape=key.shape,
                    attention_axis=attention_axis)
                mask_components.append(padding_mask)

            if segmentation is not None:
                if key_segmentation is None:
                    key_segmentation = segmentation
                segmentation_mask = attention.make_padding_mask(
                    padding_mask_query=segmentation,
                    padding_mask_key=key_segmentation,
                    query_shape=query.shape,
                    key_shape=key.shape,
                    attention_axis=attention_axis,
                    segmentation_mask=True)
                mask_components.append(segmentation_mask)

            if mask_components:
                attention_mask = mask_components[0]
                for component in mask_components[1:]:
                    attention_mask = jnp.logical_and(attention_mask, component)

                # attention mask in the form of attention bias
                attention_bias = lax.select(
                    attention_mask > 0,
                    jnp.full(attention_mask.shape, 0.).astype(dtype),
                    jnp.full(attention_mask.shape, -1e10).astype(dtype))
            else:
                attention_bias = None
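            # Hedged note: a mask value of 1 (attend) becomes a bias of 0 and
            # a mask value of 0 (don't attend) becomes -1e10, so after the
            # bias is added to the attention logits the masked positions
            # receive essentially zero softmax weight.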
            x = nn.dot_product_attention(query,
                                         key,
                                         value,
                                         dtype=dtype,
                                         axis=attention_axis,
                                         bias=attention_bias,
                                         precision=precision,
                                         dropout_rng=dropout_rng,
                                         dropout_rate=dropout_rate,
                                         broadcast_dropout=broadcast_dropout,
                                         deterministic=deterministic)
        else:
            if connectivity_seed is None:
                path = self._get_construction_frame().path
                connectivity_seed = hash(path) % 2**32
            # apply attention
            input_mask = None
            if padding_mask is not None:
                input_mask = padding_mask.astype(key.dtype)
            x = sparse_dot_product_attention(
                query,
                key,
                value,
                connectivity_seed=connectivity_seed,
                input_mask=input_mask,
                block_size=block_size,
                num_rand_blocks=num_rand_blocks)
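            # Hedged note: `sparse_dot_product_attention` is assumed (from its
            # arguments) to implement a BigBird-style pattern: block-local
            # attention of width `block_size` around the diagonal plus
            # `num_rand_blocks` randomly selected key blocks per query row,
            # with the random connectivity derived from `connectivity_seed`.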

        # back to the original inputs dimensions
        out = nn.DenseGeneral(x,
                              features=features,
                              axis=(-2, -1),
                              kernel_init=kernel_init,
                              bias_init=bias_init,
                              bias=bias,
                              dtype=dtype,
                              precision=precision,
                              name='out')

        out = out[:, :orig_seqlen, :]

        return out
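
# --- Hedged standalone sketch (not part of the example above) ---
# The pad-to-a-block-multiple / slice-back pattern used by the block-sparse
# `apply` above, in isolation. `pad_to_block_multiple` is a hypothetical
# helper written only for illustration.
import jax.numpy as jnp


def pad_to_block_multiple(x, block_size):
    """Pads axis 1 of a `[bs, length, features]` array up to a multiple of
    `block_size` and returns the padded array plus the original length so
    the caller can slice the output back afterwards."""
    orig_len = x.shape[1]
    extra = -orig_len % block_size  # 0 when length is already a multiple
    padded = jnp.pad(x, [[0, 0], [0, extra], [0, 0]])
    return padded, orig_len


x = jnp.ones((2, 100, 16))
padded, orig_len = pad_to_block_multiple(x, block_size=64)
assert padded.shape == (2, 128, 16)              # padded up to 2 blocks of 64
assert padded[:, :orig_len, :].shape == x.shape  # sliced back to the input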