Example no. 1
  def __call__(self,
               inputs):
    """Applies AddPositionEmbs module.

    Args:
      inputs: input data `[batch_size, ..., length, dim]`

    Returns:
      New embedding `[batch_size, ..., length, dim]`
    """
    cfg = self.config

    assert inputs.ndim >= 3
    flat_inputs = inputs.reshape((-1, inputs.shape[-2], inputs.shape[-1]))
    length = flat_inputs.shape[1]
    pos_emb_shape = (1, cfg.max_len, flat_inputs.shape[-1])
    if cfg.posemb_init is None:
      pos_embedding = sinusoidal_init(max_len=cfg.max_len)(
          None, pos_emb_shape, None)
    else:
      pos_embedding = self.param('pos_embedding',
                                 cfg.posemb_init,
                                 pos_emb_shape)
    pe = pos_embedding[:, :length, :]
    # We abuse the same attention Cache mechanism to run positional embeddings
    # in fast predict mode. We could use state variables instead, but this
    # simplifies invocation with a single top-level cache context manager.
    # We only use the cache's position index for tracking decoding position.
    if self.cache:
      is_initialized = self.has_variable('cache', 'cache_index')
      cache_index = self.variable('cache', 'cache_index',
                                  lambda: jnp.array(0, dtype=jnp.uint32))
      if is_initialized:
        i = cache_index.value
        cache_index.value = i + 1
        _, _, df = pos_embedding.shape
        pe = lax.dynamic_slice(pos_embedding,
                               jnp.array((0, i, 0)),
                               jnp.array((1, 1, df)))
    return (flat_inputs + pe).reshape(inputs.shape)
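A minimal standalone sketch of the decode-time trick described in the comments above (names and sizes are hypothetical): a scalar step counter selects a single row of the positional-embedding table with lax.dynamic_slice.

import jax.numpy as jnp
from jax import lax

max_len, dim = 8, 4
pos_embedding = jnp.arange(max_len * dim, dtype=jnp.float32).reshape(1, max_len, dim)

def positional_slice(step):
    # Cut a (1, 1, dim) slice at position `step`; the start index may be a traced value.
    return lax.dynamic_slice(pos_embedding, (0, step, 0), (1, 1, dim))

print(positional_slice(jnp.uint32(3)))  # row 3 of the table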
Example no. 2
    def apply(self,
              inputs,
              max_len=2048,
              posemb_init=nn.initializers.normal(stddev=1.0),
              cache=None):
        """Applies AddPositionEmbs module.

    Args:
      inputs: input data
      max_len: maximum possible length for the input
      posemb_init: positional embedding initializer
      cache: flax attention cache for fast decoding.

    Returns:
      output: `(bs, timesteps, in_dim)`
    """
        assert inputs.ndim == 3, ('Number of dimensions should be 3,'
                                  ' but it is: %d' % inputs.ndim)
        length = inputs.shape[1]
        pos_emb_shape = (1, max_len, inputs.shape[-1])
        pos_embedding = self.param('pos_embedding', pos_emb_shape, posemb_init)
        pe = pos_embedding[:, :length, :]
        # We abuse the same attention Cache mechanism to run positional embeddings
        # in fast predict mode. We could use state variables instead, but this
        # simplifies invocation with a single top-level cache context manager.
        # We only use the cache's position index for tracking decoding position.
        if cache:
            if self.is_initializing():
                cache.store(lambda: (4, (1, 1)))
            else:
                cache_entry = cache.retrieve(None)
                i = cache_entry.i
                one = jnp.array(1, jnp.uint32)
                cache_entry = cache_entry.replace(i=cache_entry.i + one)
                cache.store(cache_entry)
                _, _, df = pos_embedding.shape
                pe = lax.dynamic_slice(pos_embedding, jnp.array((0, i, 0)),
                                       jnp.array((1, 1, df)))
        return inputs + pe
Example no. 3
    def batch_fn(state):
        """Function for batching."""
        key, data, i, num_batches = state

        slice_start = ([i * batch_size_per_device, 0,
                        0], [i * batch_size_per_device, 0])

        slice_size = ([batch_size_per_device, 32,
                       32 * 3], [batch_size_per_device, 10])

        batch = [
            lax.dynamic_slice(x, start, size)
            for x, start, size in zip(data, slice_start, slice_size)
        ]

        if transform is not None:
            key, subkey = random.split(key)
            batch = transform(subkey, batch)

        i = i + 1
        key, data, i = lax.cond(i >= num_batches, (key, data), shuffle,
                                (key, data, i), lambda x: x)

        return batch, (key, data, i, num_batches)
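A self-contained sketch of the per-device batching idea above (shapes are hypothetical): lax.dynamic_slice cuts batch i of size batch_size_per_device out of a flat data array, with the start index computed from the loop counter.

import jax.numpy as jnp
from jax import lax

batch_size_per_device = 4
images = jnp.arange(12 * 32 * 96, dtype=jnp.float32).reshape(12, 32, 96)

def get_batch(i):
    # The start index along the leading axis is i * batch_size_per_device.
    return lax.dynamic_slice(images, (i * batch_size_per_device, 0, 0),
                             (batch_size_per_device, 32, 96))

print(get_batch(jnp.int32(2)).shape)  # (4, 32, 96)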
Example no. 4
        def beam_search_body_fn(state, input_ids_length=1):
            """beam search state update fn."""
            # 1. Forward current tokens
            # Collect the current position slice along length to feed the fast
            # autoregressive decoder model.  Flatten the beam dimension into batch
            # dimension for feeding into the model.
            # unflatten beam dimension
            # Unflatten beam dimension in attention cache arrays
            input_token = flatten_beam_dim(
                lax.dynamic_slice(
                    state.running_sequences,
                    (0, 0, state.cur_len - input_ids_length),
                    (batch_size, num_beams, input_ids_length),
                ))
            model_outputs = model(input_token,
                                  params=params,
                                  **state.model_kwargs)

            logits = unflatten_beam_dim(model_outputs.logits[:, -1],
                                        batch_size, num_beams)
            cache = jax.tree_map(
                lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams
                                                  ),
                model_outputs.past_key_values)

            # adapt logits for FlaxMarianMTModel
            logits = self._adapt_logits_for_beam_search(logits)

            # 2. Compute log probs
            # get log probabilities from logits,
            # process logits with processors (*e.g.* min_length, ...), and
            # add new logprobs to existing running logprobs scores.
            log_probs = jax.nn.log_softmax(logits)
            log_probs = logits_processor(
                flatten_beam_dim(state.running_sequences),
                flatten_beam_dim(log_probs), state.cur_len)
            log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
            log_probs = log_probs + jnp.expand_dims(state.running_scores,
                                                    axis=2)
            vocab_size = log_probs.shape[2]
            log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))

            # 3. Retrieve top-K
            # Each item in batch has num_beams * vocab_size candidate sequences.
            # For each item, get the top 2*k candidates with the highest log-
            # probabilities. We gather the top 2*K beams here so that even if the best
            # K sequences reach EOS simultaneously, we have another K sequences
            # remaining to continue the live beam search.
            # Gather the top 2*K scores from _all_ beams.
            # Gather 2*k top beams.
            # Recover the beam index by floor division.
            # Recover token id by modulo division and expand Id array for broadcasting.
            # Update sequences for the 2*K top-k new sequences.
            beams_to_keep = 2 * num_beams
            topk_log_probs, topk_indices = lax.top_k(log_probs,
                                                     k=beams_to_keep)
            topk_beam_indices = topk_indices // vocab_size
            topk_running_sequences = gather_beams(state.running_sequences,
                                                  topk_beam_indices,
                                                  batch_size, beams_to_keep)
            topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
            topk_sequences = lax.dynamic_update_slice(topk_running_sequences,
                                                      topk_ids,
                                                      (0, 0, state.cur_len))

            # 4. Check which sequences have ended
            # Update current sequences:
            # Did any of these sequences reach an end marker?
            # To prevent these just-finished sequences from being added to the
            # current set of active beam-search sequences, set their log probs
            # to a very large negative value.
            did_topk_just_finished = (
                topk_sequences[:, :, state.cur_len] == eos_token_id)
            running_topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(
                -1.0e7)
            # 5. Get running sequences scores for the next iteration.
            # Determine the top k beam indices (from top 2*k beams) from log probs
            # and gather top k beams (from top 2*k beams).
            next_topk_indices = jnp.flip(lax.top_k(running_topk_log_probs,
                                                   k=num_beams)[1],
                                         axis=1)
            next_running_sequences, next_running_scores = gather_beams(
                [topk_sequences, running_topk_log_probs], next_topk_indices,
                batch_size, num_beams)

            # 6. Process topk logits
            # Further process log probs:
            # - add length penalty
            # - make sure no scores can be added anymore if beam is full
            # - make sure still running sequences cannot be chosen as finalized beam
            topk_log_probs = topk_log_probs / (state.cur_len**length_penalty)
            beams_in_batch_are_full = (jnp.broadcast_to(
                state.is_sent_finished.all(axis=-1, keepdims=True),
                did_topk_just_finished.shape)
                                       & early_stopping)
            add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
            topk_log_probs += add_penalty * np.array(-1.0e7)

            # 7. Get scores, sequences, and is-sentence-finished flags for the next iteration.
            # Combine sequences, scores, and flags along the beam dimension and compare
            # new finished sequence scores to existing finished scores and select the
            # best from the new set of beams
            merged_sequences = jnp.concatenate(
                [state.sequences, topk_sequences], axis=1)
            merged_scores = jnp.concatenate([state.scores, topk_log_probs],
                                            axis=1)
            merged_is_sent_finished = jnp.concatenate(
                [state.is_sent_finished, did_topk_just_finished], axis=1)
            topk_merged_indices = jnp.flip(lax.top_k(merged_scores,
                                                     k=num_beams)[1],
                                           axis=1)
            next_sequences, next_scores, next_is_sent_finished = gather_beams(
                [merged_sequences, merged_scores, merged_is_sent_finished],
                topk_merged_indices, batch_size, num_beams)

            # 8. Update model kwargs.
            # Determine the top k beam indices from the original set of all beams.
            # With these, gather the top k beam-associated caches.
            next_running_indices = gather_beams(topk_beam_indices,
                                                next_topk_indices, batch_size,
                                                num_beams)
            next_cache = gather_beams(cache, next_running_indices, batch_size,
                                      num_beams)
            model_outputs["past_key_values"] = jax.tree_map(
                lambda x: flatten_beam_dim(x), next_cache)
            next_model_kwargs = self.update_inputs_for_generation(
                model_outputs, state.model_kwargs)

            return BeamSearchState(
                cur_len=state.cur_len + 1,
                running_scores=next_running_scores,
                running_sequences=next_running_sequences,
                scores=next_scores,
                sequences=next_sequences,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
            )
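A tiny numeric check of the index arithmetic used in step 3 above (sizes are hypothetical): after flattening the num_beams * vocab_size axis, floor division recovers the originating beam and the modulo recovers the token id.

import jax.numpy as jnp

num_beams, vocab_size = 2, 5
flat_indices = jnp.array([7, 3])   # top-k indices into the flattened beam*vocab axis
print(flat_indices // vocab_size)  # [1 0] -> beam index
print(flat_indices % vocab_size)   # [2 3] -> token id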
Example no. 5
 def f_jax(x, idx):  # x.shape : (4,)
     return lax.dynamic_slice(x, idx, slice_sizes=(2, ))
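A usage sketch for the one-liner above (redefined here so the snippet is self-contained): under jit the start index may be a traced value, while the slice size stays a static Python constant.

import jax
import jax.numpy as jnp
from jax import lax

def f_jax(x, idx):  # x.shape : (4,)
    return lax.dynamic_slice(x, idx, slice_sizes=(2,))

print(jax.jit(f_jax)(jnp.arange(4.0), (jnp.int32(1),)))  # [1. 2.]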
Example no. 6
                  poly_axes=[0, None]),

    _make_harness("cummax", "",
                  lambda x: lax_control_flow.cummax(x, axis=1, reverse=False),
                  [RandArg((3, 4, 5), _f32)],
                  poly_axes=[0]),

    _make_harness("dot_general", "",
                  lambda lhs, rhs: lax.dot_general(lhs, rhs,
                                                   dimension_numbers=(((2,), (1,)), ((0,), (0,)))),
                  [RandArg((3, 4, 4), _f32), RandArg((3, 4), _f32)],
                  poly_axes=[0, 0]),

    _make_harness("dynamic_slice", "",
                  # x:shape: (b, 4)
                  lambda x: lax.dynamic_slice(x, (0, 1), (x.shape[0], 2)),
                  [RandArg((3, 4), _f32)],
                  poly_axes=[0]),

    _make_harness("dynamic_update_slice", "",
                  # x:shape: (b, 4)
                  lambda x: lax.dynamic_update_slice(x, x, (0, 0)),
                  [RandArg((3, 4), _f32)],
                  poly_axes=[0]),

    _make_harness("einsum", "0",
                  lambda x: jnp.einsum("...i->...", x),
                  [RandArg((3, 4), _f32)],
                  poly_axes=[0]),

    _make_harness("einsum", "1",
Example no. 7
    def apply_fast_fun(params, inputs, cache, index, **kwargs):
        W, b = params

        batch = inputs.shape[lhs_spec.index('N')]
        L = cache.shape[lhs_spec.index('W')]
        in_chan = inputs.shape[lhs_spec.index('C')]

        index_h, index_w = index
        if exclusive:
            index_h_in, index_w_in = prev_index_2d(index_h, index_w, L)
        else:
            index_h_in, index_w_in = index_h, index_w

        # First, update the cache

        def _add(cache):
            # return cache.at[:, -1, index_w_in, :].set(inputs)
            return lax.dynamic_update_slice(cache, inputs,
                                            (0, -1, index_w_in, 0))

        def _shift(cache):
            return jnp.concatenate(
                [
                    cache[:, 1:, :, :],
                    jnp.zeros((batch, 1, L, in_chan), dtype=cache.dtype)
                ],
                axis=1,
            )

        def _new_row(cache):
            return lax.cond(
                index_w_in == 0,
                lambda x: _add(_shift(x)),
                lambda x: _shift(_add(x)),
                cache,
            )

        def _update(cache):
            return lax.cond(index_w == 0, _new_row, _add, cache)

        cache = lax.cond(index_h_in >= 0, _update, lambda x: x, cache)

        # Then, use the cache to compute the outputs (the inputs are not used)

        # Zero padding
        cache_slice = jnp.pad(cache, (
            (0, 0),
            (0, 0),
            (kernel_w // 2 * dilation_w, (kernel_w - 1) // 2 * dilation_w),
            (0, 0),
        ))

        # cache = cache[:, :, index_w : index_w + recep_w, :]
        cache_slice = lax.dynamic_slice(cache_slice, (0, 0, index_w, 0),
                                        (batch, recep_h, recep_w, in_chan))

        out = lax.conv_general_dilated(
            cache_slice,
            mask * W,
            window_strides=(1, 1),
            padding='VALID',
            lhs_dilation=(1, 1),
            rhs_dilation=dilation,
            dimension_numbers=dimension_numbers,
        )
        assert out.shape == (batch, 1, 1, out_chan)
        out += b
        return out, cache
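A minimal sketch of the cache read above (sizes are hypothetical): pad the width axis of the cache, then cut a receptive-field window starting at index_w with lax.dynamic_slice.

import jax.numpy as jnp
from jax import lax

batch, recep_h, L, in_chan = 2, 3, 6, 4
kernel_w, dilation_w = 3, 1
recep_w = (kernel_w - 1) * dilation_w + 1

cache = jnp.zeros((batch, recep_h, L, in_chan))
padded = jnp.pad(cache, ((0, 0), (0, 0),
                         (kernel_w // 2 * dilation_w,
                          (kernel_w - 1) // 2 * dilation_w),
                         (0, 0)))
window = lax.dynamic_slice(padded, (0, 0, jnp.int32(2), 0),
                           (batch, recep_h, recep_w, in_chan))
print(window.shape)  # (2, 3, 3, 4)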
Example no. 8
 def f_jax(arr):
   return lax.dynamic_slice(arr, [100], [1])  # out of bounds, should return the last element
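A quick demonstration of the clamping behavior the comment relies on: dynamic_slice adjusts an out-of-bounds start index so the slice stays within the array.

import jax.numpy as jnp
from jax import lax

arr = jnp.arange(4.0)
print(lax.dynamic_slice(arr, [100], [1]))  # [3.] -- the start index is clamped to 3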
Example no. 9
 def scalar_f(x):
   return lax.dynamic_slice(x, [], [])
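A usage sketch for the rank-0 case above: with a scalar operand both the start-index and slice-size lists are empty, and the result is just the scalar.

import jax.numpy as jnp
from jax import lax

print(lax.dynamic_slice(jnp.asarray(5.0), [], []))  # 5.0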
Example no. 10
    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):
        qkv_out = self.c_attn(hidden_states)
        query, key, value = jnp.split(qkv_out, 3, axis=2)

        query = self._split_heads(query)
        key = self._split_heads(key)
        value = self._split_heads(value)

        query_length, key_length = query.shape[1], key.shape[1]

        if self.has_variable("cache", "cached_key"):
            mask_shift = self.variables["cache"]["cache_index"]
            max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
            causal_mask = lax.dynamic_slice(
                self.causal_mask, (0, 0, mask_shift, 0),
                (1, 1, query_length, max_decoder_length))
        else:
            causal_mask = self.causal_mask[:, :, :query_length, :key_length]

        batch_size = hidden_states.shape[0]
        causal_mask = jnp.broadcast_to(causal_mask,
                                       (batch_size, ) + causal_mask.shape[1:])

        attention_mask = jnp.broadcast_to(
            jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
        attention_mask = combine_masks(attention_mask, causal_mask)

        dropout_rng = None
        if not deterministic and self.config.attn_pdrop > 0.0:
            dropout_rng = self.make_rng("dropout")

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.has_variable("cache", "cached_key") or init_cache:
            key, value, attention_mask = self._concatenate_to_cache(
                key, value, query, attention_mask)

        # transform boolean mask into float mask
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
        )

        # usual dot product attention
        attn_output = dot_product_attention(
            query,
            key,
            value,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attn_pdrop,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = self._merge_heads(attn_output)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output,
                                         deterministic=deterministic)

        # TODO: at the moment it's not possible to retrieve attn_weights from
        # dot_product_attention, but should be in the future -> add functionality then

        return (attn_output, )
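A standalone sketch of the cached-decoding mask slice used above (sizes are hypothetical): at decode step cache_index a single query row of the precomputed causal mask is cut out with lax.dynamic_slice.

import jax.numpy as jnp
from jax import lax
from flax.linen import make_causal_mask

max_decoder_length = 8
causal_mask = make_causal_mask(
    jnp.ones((1, max_decoder_length), dtype="bool"), dtype="bool")

def mask_for_step(cache_index, query_length=1):
    # One (1, 1, query_length, max_decoder_length) row of the mask per step.
    return lax.dynamic_slice(causal_mask, (0, 0, cache_index, 0),
                             (1, 1, query_length, max_decoder_length))

print(mask_for_step(jnp.int32(3))[0, 0, 0])  # positions 0-3 visible, rest masked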
Example no. 11
 def _onerow(m, x, i):
     return jnp.dot(m[i], lax.dynamic_slice(x, [i], [_W]))
Example no. 12
def skip_slice(inputs, output_width):
    """Slice in the time dimension, getting the last output_width elements"""
    skip_cut = inputs.shape[1] - output_width
    slice_sizes = [inputs.shape[0], output_width, inputs.shape[2]]
    return lax.dynamic_slice(inputs, (0, skip_cut, 0), slice_sizes)
Example no. 13
    def __call__(self,
                 inputs_q,
                 inputs_kv,
                 mask=None,
                 custom_relative_position=None,
                 deterministic=None):
        """Applies multi-head dot product attention on the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention and project the results to an output vector.

    Args:
      inputs_q: input queries of shape
        `[batch_sizes..., length, features]`.
      inputs_kv: key/values of shape
        `[batch_sizes..., length, features]`.
      mask: attention mask of shape
        `[batch_sizes..., num_heads, query_length, key/value_length]`.
        Attention weights are masked out if their corresponding mask value
        is `False`.
      custom_relative_position: relative positions tensor
        `[batch_sizes..., query_length, key/value_length]`.
      deterministic: if false, the attention weight is masked randomly
        using dropout, whereas if true, the attention weights
        are deterministic.

    Returns:
      output of shape `[batch_sizes..., length, features]`.
    """
        if self.dropout_rate > 0.:  # Require `deterministic` only if using dropout.
            deterministic = module.merge_param('deterministic',
                                               self.deterministic,
                                               deterministic)
        features = self.out_features or inputs_q.shape[-1]
        qkv_features = self.qkv_features or inputs_q.shape[-1]
        assert qkv_features % self.num_heads == 0, (
            'Memory dimension must be divisible by number of heads.')
        head_dim = qkv_features // self.num_heads

        dense = functools.partial(linear.DenseGeneral,
                                  axis=-1,
                                  features=(self.num_heads, head_dim),
                                  kernel_init=self.kernel_init,
                                  bias_init=self.bias_init,
                                  use_bias=self.use_bias,
                                  precision=self.precision)
        relative_attention_embed = linear.Embed(
            num_embeddings=self.num_relative_position_buckets,
            features=self.num_heads,
            embedding_init=initializers.normal(stddev=1.0),
            dtype=self.dtype)

        # project inputs_q to multi-headed q/k/v
        # dimensions are then [batch..., length, n_heads, n_features_per_head]
        query, key, value = (dense(dtype=self.dtype, name='query')(inputs_q),
                             dense(dtype=self.dtype, name='key')(inputs_kv),
                             dense(dtype=self.dtype, name='value')(inputs_kv))

        if custom_relative_position is None:
            query_length = inputs_q.shape[-2]
            key_length = inputs_kv.shape[-2]
            context_position = jnp.arange(query_length, dtype=jnp.int32)[:,
                                                                         None]
            memory_position = jnp.arange(key_length, dtype=jnp.int32)[None, :]

            relative_position = memory_position - context_position
            relative_position_bucket = make_relative_position_bucket(
                relative_position,
                bidirectional=self.bidirectional,
                num_buckets=self.num_relative_position_buckets,
                max_distance=self.max_distance)

            bias = relative_attention_embed(relative_position_bucket)
            bias = bias.transpose((2, 0, 1))
            # Expand batch dimensions.
            bias = jnp.broadcast_to(bias, (1, ) * len(inputs_q.shape[:-2]) +
                                    bias.shape)

        else:
            relative_position = custom_relative_position
            relative_position_bucket = make_relative_position_bucket(
                relative_position,
                bidirectional=self.bidirectional,
                num_buckets=self.num_relative_position_buckets,
                max_distance=self.max_distance)

            bias = relative_attention_embed(relative_position_bucket)
            permute = tuple(
                map(lambda i: len(inputs_q.shape) + 1 + i, (-1, -3, -2)))
            bias = bias.transpose(
                tuple(range(len(inputs_q.shape[:-2]))) + permute)

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.decode:
            # detect if we're initializing by absence of existing cache data.
            is_initialized = self.has_variable('cache', 'cached_key')
            cached_key = self.variable('cache', 'cached_key', jnp.zeros,
                                       key.shape, key.dtype)
            cached_value = self.variable('cache', 'cached_value', jnp.zeros,
                                         value.shape, value.dtype)
            cache_index = self.variable('cache', 'cache_index',
                                        lambda: jnp.array(0, dtype=jnp.int32))
            if is_initialized:
                *batch_dims, max_length, num_heads, depth_per_head = (
                    cached_key.value.shape)
                # shape check of cached keys against query input
                expected_shape = tuple(batch_dims) + (1, num_heads,
                                                      depth_per_head)
                if expected_shape != query.shape:
                    raise ValueError(
                        'Autoregressive cache shape error, '
                        'expected query shape %s instead got %s.' %
                        (expected_shape, query.shape))
                # update key, value caches with our new 1d spatial slices
                cur_index = cache_index.value
                indices = (0, ) * len(batch_dims) + (cur_index, 0, 0)
                key = lax.dynamic_update_slice(cached_key.value, key, indices)
                value = lax.dynamic_update_slice(cached_value.value, value,
                                                 indices)
                cached_key.value = key
                cached_value.value = value
                cache_index.value = cache_index.value + 1
                # causal mask for cached decoder self-attention:
                # our single query position should only attend to those key
                # positions that have already been generated and cached,
                # not the remaining zero elements.
                mask = attention.combine_masks(
                    mask,
                    jnp.broadcast_to(
                        jnp.arange(max_length) <= cur_index,
                        tuple(batch_dims) + (1, 1, max_length)))

                bias = lax.dynamic_slice(bias, (0, 0, cur_index, 0),
                                         (1, self.num_heads, 1, max_length))

        # Convert the boolean attention mask to an attention bias.
        if mask is not None:
            # attention mask in the form of attention bias
            bias += lax.select(mask > 0,
                               jnp.full(mask.shape, 0.).astype(self.dtype),
                               jnp.full(mask.shape, -1e10).astype(self.dtype))

        dropout_rng = None
        if not deterministic and self.dropout_rate > 0.:
            dropout_rng = self.make_rng('dropout')

        # apply attention
        x = attention.dot_product_attention(
            query,
            key,
            value,
            bias=bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout_rate,
            broadcast_dropout=self.broadcast_dropout,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=self.precision)  # pytype: disable=wrong-keyword-args
        # back to the original inputs dimensions
        out = linear.DenseGeneral(features=features,
                                  axis=(-2, -1),
                                  kernel_init=self.kernel_init,
                                  bias_init=self.bias_init,
                                  use_bias=self.use_bias,
                                  dtype=self.dtype,
                                  precision=self.precision,
                                  name='out')(x)
        return out
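A minimal sketch of the decode-time cache write above (shapes are hypothetical): a single-step key of shape (batch, 1, heads, depth) is written into the preallocated (batch, max_length, heads, depth) cache at cur_index with lax.dynamic_update_slice.

import jax.numpy as jnp
from jax import lax

batch, max_length, num_heads, depth_per_head = 2, 6, 1, 4
cached_key = jnp.zeros((batch, max_length, num_heads, depth_per_head))
new_key = jnp.ones((batch, 1, num_heads, depth_per_head))
cur_index = jnp.int32(3)

cached_key = lax.dynamic_update_slice(cached_key, new_key, (0, cur_index, 0, 0))
print(cached_key[0, :, 0, 0])  # [0. 0. 0. 1. 0. 0.]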
Example no. 14
    def update_site(self, inputs: Array, index: int) -> Array:
        """
        Adds an input site into the cache, and applies the masked convolution to the cache.

        Args:
          inputs: an input site to be added into the cache with dimensions (batch, features).
          index: the index of the output site. The index of the input site should be `index - self.exclusive`.

        Returns:
          The next output site with dimensions (batch, features).
        """
        dtype = jnp.promote_types(inputs.dtype, self.dtype)

        inputs = jnp.asarray(inputs, dtype)

        L = self.L
        index_w = index % L

        kernel_h, kernel_w = self.kernel_size
        dilation_h, dilation_w = self.kernel_dilation
        ones = (1, 1)

        is_single_input = False
        if inputs.ndim == 1:
            is_single_input = True
            inputs = jnp.expand_dims(inputs, axis=0)

        batch, in_features = inputs.shape
        assert in_features % self.feature_group_count == 0
        recep_h = (kernel_h - 1) * dilation_h + 1
        recep_w = (kernel_w - 1) * dilation_w + 1

        # Initialize the cache with zeros, and the RNG key is None
        # `cache.dtype` must be the same as `inputs.dtype` (no promotion)
        _cache = self.variable(
            "cache",
            "inputs",
            zeros,
            None,
            (batch, recep_h, L, in_features),
            inputs.dtype,
        )

        initializing = self.is_mutable_collection("params")
        if not initializing:
            # Add the input site into the cache
            # To write the cache, use `_cache.value` as the left value of the assignment

            inputs = jnp.expand_dims(inputs, axis=(1, 2))

            # Index of the input site in the width direction
            index_w_in = (index - self.exclusive) % L

            def _add(cache):
                # return cache.at[:, -1, index_w_in, :].set(inputs)
                return lax.dynamic_update_slice(cache, inputs,
                                                (0, -1, index_w_in, 0))

            def _shift(cache):
                return jnp.concatenate(
                    [
                        cache[:, 1:, :, :],
                        jnp.zeros(
                            (batch, 1, L, in_features), dtype=inputs.dtype),
                    ],
                    axis=1,
                )

            cache_new_row = lax.cond(
                index_w_in == 0,
                lambda _: _add(_shift(_cache.value)),
                lambda _: _shift(_add(_cache.value)),
                None,
            )

            cache_new = lax.cond(
                index_w == 0,
                lambda _: cache_new_row,
                lambda _: _add(_cache.value),
                None,
            )

            _cache.value = lax.cond(
                index - self.exclusive >= 0,
                lambda _: cache_new,
                lambda _: _cache.value,
                None,
            )

        cache = _cache.value
        cache = jnp.asarray(cache, dtype)

        kernel_shape = self.kernel_size + (
            in_features // self.feature_group_count,
            self.features,
        )
        kernel = self.param(
            "kernel",
            wrap_kernel_init(self.kernel_init, self.mask),
            kernel_shape,
            self.dtype,
        )
        kernel = jnp.asarray(kernel, dtype)

        # Zero padding
        cache = jnp.pad(
            cache,
            (
                (0, 0),
                (0, 0),
                (kernel_w // 2 * dilation_w, (kernel_w - 1) // 2 * dilation_w),
                (0, 0),
            ),
        )

        # cache = cache[:, :, index_w : index_w + recep_w, :]
        cache = lax.dynamic_slice(cache, (0, 0, index_w, 0),
                                  (batch, recep_h, recep_w, in_features))

        dimension_numbers = flax.linen.linear._conv_dimension_numbers(
            cache.shape)
        y_i = lax.conv_general_dilated(
            cache,
            kernel,
            window_strides=ones,
            padding="VALID",
            lhs_dilation=ones,
            rhs_dilation=self.kernel_dilation,
            dimension_numbers=dimension_numbers,
            feature_group_count=self.feature_group_count,
            precision=self.precision,
        )

        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.features, ),
                              self.dtype)
            bias = jnp.asarray(bias, dtype)
            y_i = y_i + bias

        y_i = y_i.squeeze(axis=(1, 2))

        if is_single_input:
            y_i = y_i.squeeze(axis=0)

        return y_i
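A standalone sketch of the conditional cache update above (shapes are hypothetical, and the write goes to an explicit last-row index rather than -1): lax.cond chooses between shifting the rows up before writing the new site and writing it directly.

import jax.numpy as jnp
from jax import lax

batch, rows, L, feat = 1, 2, 3, 1
cache = jnp.arange(rows * L, dtype=jnp.float32).reshape(batch, rows, L, feat)
new_site = jnp.full((batch, 1, 1, feat), 9.0)
index_w_in = jnp.int32(0)

def _add(c):
    # Write the new site into the last row at column index_w_in.
    return lax.dynamic_update_slice(c, new_site, (0, rows - 1, index_w_in, 0))

def _shift(c):
    # Drop the oldest row and append a fresh row of zeros.
    return jnp.concatenate([c[:, 1:], jnp.zeros((batch, 1, L, feat))], axis=1)

updated = lax.cond(index_w_in == 0,
                   lambda c: _add(_shift(c)),
                   lambda c: _shift(_add(c)),
                   cache)
print(updated[0, :, :, 0])  # [[3. 4. 5.] [9. 0. 0.]]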
Example no. 15
    def __call__(
        self,
        hidden_states,
        key_value_states: Optional[jnp.ndarray] = None,
        attention_mask=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        batch_size = hidden_states.shape[0]

        if not is_cross_attention:
            qkv_out = self.c_attn(hidden_states)
            query, key, value = jnp.split(qkv_out, 3, axis=2)
        else:
            q_out = self.q_attn(hidden_states)
            (query, ) = jnp.split(q_out, 1, axis=2)
            kv_out = self.c_attn(key_value_states)
            key, value = jnp.split(kv_out, 2, axis=2)

        query = self._split_heads(query)
        key = self._split_heads(key)
        value = self._split_heads(value)

        query_length, key_length = query.shape[1], key.shape[1]

        if self.causal:
            if self.has_variable("cache", "cached_key"):
                mask_shift = self.variables["cache"]["cache_index"]
                max_decoder_length = self.variables["cache"][
                    "cached_key"].shape[1]
                causal_mask = lax.dynamic_slice(
                    self.causal_mask, (0, 0, mask_shift, 0),
                    (1, 1, query_length, max_decoder_length))
            else:
                causal_mask = self.causal_mask[:, :, :query_length, :key_length]
            causal_mask = jnp.broadcast_to(causal_mask, (batch_size, ) +
                                           causal_mask.shape[1:])

        # combine masks if needed
        if attention_mask is not None and self.causal:
            attention_mask = jnp.broadcast_to(
                jnp.expand_dims(attention_mask, axis=(-3, -2)),
                causal_mask.shape)
            attention_mask = combine_masks(attention_mask, causal_mask)
        elif self.causal:
            attention_mask = causal_mask
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        dropout_rng = None
        if not deterministic and self.config.attn_pdrop > 0.0:
            dropout_rng = self.make_rng("dropout")

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.causal and (self.has_variable("cache", "cached_key")
                            or init_cache):
            key, value, attention_mask = self._concatenate_to_cache(
                key, value, query, attention_mask)

        # transform boolean mask into float mask
        if attention_mask is not None:
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
            )
        else:
            attention_bias = None

        # usual dot product attention
        attn_weights = dot_product_attention_weights(
            query,
            key,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attn_pdrop,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output,
                                         deterministic=deterministic)

        outputs = (attn_output,
                   attn_weights) if output_attentions else (attn_output, )
        return outputs
Example no. 16
def sliced_transposed_product(
        mat,
        block_size,
        axes=(-1, ),
        precision=lax.Precision.DEFAULT,
):
    """Returns the blocked slices representing a symmetric contraction.

  Specifically, the output is a contraction of the input mat with itself, in the
  specified axes.

  Args:
    mat: The matrix for which we will compute a contraction with itself.
    block_size: The size of row blocks to compute.
    axes: Axes to use for the contraction.
    precision: The precision to use in each computation.

  Returns:
    A SlicedSymmetricMatrix holding the computed block rows of the contraction.

  Raises:
    ValueError: Raised when the specified block size does not evenly divide
      the number of rows of the input mat.
  """
    rank = len(mat.shape)

    def _make_axis_positive(ax):
        assert -rank <= ax < rank
        return ax + rank if ax < 0 else ax

    positive_axes = [_make_axis_positive(ax) for ax in axes]
    assert len(positive_axes) == len(axes)
    remaining_axes = set(range(rank)) - set(positive_axes)
    assert len(remaining_axes) == 1
    remaining_ax = remaining_axes.pop()

    num_rows = mat.shape[remaining_ax]
    if num_rows % block_size != 0:
        raise ValueError(
            "The row dimension must be divisible by block_size. "
            f"Instead got row dimension={num_rows} and block_size={block_size}."
        )

    block_rows = []
    for i in range(num_rows // block_size):
        start_indices = [0] * rank
        start_indices[remaining_ax] = i * block_size

        slice_sizes = list(mat.shape)
        slice_sizes[remaining_ax] = block_size

        slice_sizes_full = list(mat.shape)
        slice_sizes_full[remaining_ax] = (i + 1) * block_size

        block_rows.append(
            product_with_transpose(
                lax.dynamic_slice(mat,
                                  start_indices=start_indices,
                                  slice_sizes=slice_sizes),
                lax.dynamic_slice(mat,
                                  start_indices=[0] * rank,
                                  slice_sizes=slice_sizes_full),
                axes=(axes, axes),
                precision=precision))

    return SlicedSymmetricMatrix(block_rows=block_rows)
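A small numeric sketch of the blocking above (sizes are hypothetical, and product_with_transpose is assumed to contract over the given axes): block row i is the product of rows [i*B, (i+1)*B) of mat with the transpose of rows [0, (i+1)*B), i.e. the lower block-triangular part of mat @ mat.T.

import jax.numpy as jnp
from jax import lax

mat = jnp.arange(12, dtype=jnp.float32).reshape(4, 3)
B = 2
block_rows = []
for i in range(mat.shape[0] // B):
    rows = lax.dynamic_slice(mat, (i * B, 0), (B, mat.shape[1]))
    prefix = lax.dynamic_slice(mat, (0, 0), ((i + 1) * B, mat.shape[1]))
    block_rows.append(rows @ prefix.T)   # shape (B, (i + 1) * B)

print([b.shape for b in block_rows])     # [(2, 2), (2, 4)]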
Example no. 17
 def entry(xc, i):
     xcpad = jnp.pad(xc, pad_width=_W // 2)
     res = jnp.dot(m[i, :_W // 2], lax.dynamic_slice(xcpad, [i], [_W // 2]))
     entryi = b[i] - res
     xc = xc.at[i].set(entryi)
     return xc, None
Example no. 18
    def __call__(
        self,
        hidden_states: jnp.ndarray,
        key_value_states: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        batch_size = hidden_states.shape[0]

        # get query proj
        query_states = self.q_proj(hidden_states)
        # get key, value proj
        if is_cross_attention:
            # cross_attentions
            key_states = self.k_proj(key_value_states)
            value_states = self.v_proj(key_value_states)
        else:
            # self_attention
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        # handle cache prepare causal attention mask
        if self.causal:
            query_length, key_length = (query_states.shape[1],
                                        key_states.shape[1])
            if self.has_variable("cache", "cached_key"):
                mask_shift = self.variables["cache"]["cache_index"]
                max_decoder_length = self.variables["cache"][
                    "cached_key"].shape[1]
                causal_mask = lax.dynamic_slice(
                    self.causal_mask, (0, 0, mask_shift, 0),
                    (1, 1, query_length, max_decoder_length))
            else:
                causal_mask = self.causal_mask[:, :, :query_length, :key_length]
            causal_mask = jnp.broadcast_to(causal_mask, (batch_size, ) +
                                           causal_mask.shape[1:])

        # combine masks if needed
        if attention_mask is not None and self.causal:
            attention_mask = jnp.broadcast_to(
                jnp.expand_dims(attention_mask, axis=(-3, -2)),
                causal_mask.shape)
            attention_mask = combine_masks(attention_mask, causal_mask)
        elif self.causal:
            attention_mask = causal_mask
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.causal and (self.has_variable("cache", "cached_key")
                            or init_cache):
            key_states, value_states, attention_mask = self._concatenate_to_cache(
                key_states, value_states, query_states, attention_mask)

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape,
                         jnp.finfo(self.dtype).min).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights,
                                 value_states)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
Example no. 19
    def apply(self,
              inputs,
              inputs_positions=None,
              min_timescale=1.0,
              max_timescale=10000.0,
              max_len=147,
              cache=None):
        """Adds positional embeddings to the inputs.

    Args:
      inputs: input data
      inputs_positions: input position indices for packed sequences.
      min_timescale: minimum scale that will be applied at each position
      max_timescale: maximum scale that will be applied at each position
      max_len: int: maximum length of sequence during eval.
      cache: flax attention cache for fast decoding.

    Returns:
      output: `(bs, timesteps, in_dim)`
    """
        assert inputs.ndim == 3, ('Number of dimensions should be 3,'
                                  ' but it is: %d' % inputs.ndim)
        length = inputs.shape[1]
        channels = inputs.shape[2]
        num_timescales = channels // 2
        log_timescale_increment = (np.log(max_timescale / min_timescale) /
                                   (float(num_timescales) - 1))
        inv_timescales = min_timescale * np.exp(
            np.arange(num_timescales).astype('float32') *
            -log_timescale_increment)

        if inputs_positions is None:
            inputs_positions = np.expand_dims(
                np.arange(length).astype('float32'), 0)

        if cache:
            inputs_positions = np.expand_dims(
                np.arange(max_len).astype('float32'), 0)

        scaled_time = (jnp.expand_dims(inputs_positions.astype('float32'), 2) *
                       jnp.expand_dims(np.expand_dims(inv_timescales, 0), 0))
        signal = jnp.concatenate(
            [jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=2)
        signal = jnp.pad(signal, [[0, 0], [0, 0], [0, np.mod(channels, 2)]],
                         mode='constant',
                         constant_values=inputs.dtype.type(0))

        # We abuse the same attention Cache mechanism to run positional embeddings
        # in fast predict mode. We could use state variables instead, but this
        # simplifies invocation with a single top-level cache context manager.
        # We only use the cache's position index for tracking decoding position.
        if cache:
            if self.is_initializing():
                cache.store(lambda: (4, (1, 1)))
            else:
                cache_entry = cache.retrieve(None)
                i = cache_entry.i
                one = jnp.array(1, jnp.uint32)
                cache_entry = cache_entry.replace(i=cache_entry.i + one)
                cache.store(cache_entry)
                _, _, df = signal.shape
                signal = lax.dynamic_slice(signal, jnp.array((0, i, 0)),
                                           jnp.array((1, 1, df)))
        if cache:
            # just needed to set correct shape on init.
            return inputs + signal[:, :1, :]
        else:
            return inputs + signal
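A standalone numeric sketch of the sinusoidal signal construction above (length and channel count are hypothetical): timescales are spaced geometrically between min_timescale and max_timescale, and the sin and cos halves are concatenated along the feature axis.

import numpy as np
import jax.numpy as jnp

length, channels = 5, 8
min_timescale, max_timescale = 1.0, 10000.0
num_timescales = channels // 2
log_timescale_increment = (np.log(max_timescale / min_timescale) /
                           (float(num_timescales) - 1))
inv_timescales = min_timescale * np.exp(
    np.arange(num_timescales).astype('float32') * -log_timescale_increment)
scaled_time = (np.arange(length).astype('float32')[:, None] *
               inv_timescales[None, :])
signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)
print(signal.shape)  # (5, 8)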
Example no. 20
    def beam_search_loop_body_fn(state):
        """Beam search loop state update function."""
        # Collect the current position slice along length to feed the fast
        # autoregressive decoder model.  Flatten the beam dimension into batch
        # dimension for feeding into the model.
        # --> [batch * beam, 1]
        flat_ids = flatten_beam_dim(
            lax.dynamic_slice(state.live_seqs, (0, 0, state.cur_index),
                              (batch_size, beam_size, 1)))
        # Flatten beam dimension into batch to be compatible with model.
        # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...}
        flat_cache = jax.tree_map(flatten_beam_dim, state.cache)

        # Call fast-decoder model on current tokens to get next-position logits.
        # --> [batch * beam, vocab]
        flat_logits, new_flat_cache = tokens_to_logits(flat_ids, flat_cache)

        # unflatten beam dimension
        # [batch * beam, vocab] --> [batch, beam, vocab]
        logits = unflatten_beam_dim(flat_logits, batch_size, beam_size)
        # Unflatten beam dimension in attention cache arrays
        # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...}
        new_cache = jax.tree_map(
            lambda x: unflatten_beam_dim(x, batch_size, beam_size),
            new_flat_cache)

        # Gather log probabilities from logits
        candidate_log_probs = jax.nn.log_softmax(logits)
        # Add new logprobs to existing prefix logprobs.
        # --> [batch, beam, vocab]
        log_probs = (candidate_log_probs +
                     jnp.expand_dims(state.live_logprobs, axis=2))

        # We'll need the vocab size, gather it from the log probability dimension.
        vocab_size = log_probs.shape[2]

        # Each item in batch has beam_size * vocab_size candidate sequences.
        # For each item, get the top 2*k candidates with the highest log-
        # probabilities. We gather the top 2*K beams here so that even if the best
        # K sequences reach EOS simultaneously, we have another K sequences
        # remaining to continue the live beam search.
        beams_to_keep = 2 * beam_size
        # Flatten beam and vocab dimensions.
        flat_log_probs = log_probs.reshape(
            (batch_size, beam_size * vocab_size))
        # Gather the top 2*K scores from _all_ beams.
        # --> [batch, 2*beams], [batch, 2*beams]
        topk_log_probs, topk_indices = lax.top_k(flat_log_probs,
                                                 k=beams_to_keep)
        # Recover the beam index by floor division.
        topk_beam_indices = topk_indices // vocab_size
        # Gather 2*k top beams.
        # --> [batch, 2*beams, length]
        topk_seq = gather_beams(state.live_seqs, topk_beam_indices, batch_size,
                                beams_to_keep)

        # Append the most probable 2*K token IDs to the top 2*K sequences
        # Recover token id by modulo division and expand Id array for broadcasting.
        # --> [batch, 2*beams, 1]
        topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
        # Update sequences for the 2*K top-k new sequences.
        # --> [batch, 2*beams, length]
        topk_seq = lax.dynamic_update_slice(topk_seq, topk_ids,
                                            (0, 0, state.cur_index + 1))

        # Update LIVE (in-progress) sequences:
        # Did any of these sequences reach an end marker?
        # --> [batch, 2*beams]
        newly_finished = (topk_seq[:, :, state.cur_index + 1] == end_marker)
        # To prevent these newly finished sequences from being added to the LIVE
        # set of active beam search sequences, set their log probs to a very large
        # negative value.
        new_log_probs = topk_log_probs + newly_finished * NEG_INF
        # Determine the top k beam indices (from top 2*k beams) from log probs.
        # --> [batch, beams]
        _, new_topk_indices = lax.top_k(new_log_probs, k=beam_size)
        new_topk_indices = jnp.flip(new_topk_indices, axis=1)
        # Gather the top k beams (from top 2*k beams).
        # --> [batch, beams, length], [batch, beams]
        top_alive_seq, top_alive_log_probs = gather_beams(
            [topk_seq, new_log_probs], new_topk_indices, batch_size, beam_size)

        # Determine the top k beam indices from the original set of all beams.
        # --> [batch, beams]
        top_alive_indices = gather_beams(topk_beam_indices, new_topk_indices,
                                         batch_size, beam_size)
        # With these, gather the top k beam-associated caches.
        # --> {[batch, beams, ...], ...}
        top_alive_cache = gather_beams(new_cache, top_alive_indices,
                                       batch_size, beam_size)

        # Update FINISHED (reached end of sentence) sequences:
        # Calculate new seq scores from log probabilities.
        new_scores = topk_log_probs / brevity_penalty(alpha,
                                                      state.cur_index + 1)
        # Mask out the still unfinished sequences by adding large negative value.
        # --> [batch, 2*beams]
        new_scores += (~newly_finished) * NEG_INF

        # Combine sequences, scores, and flags along the beam dimension and compare
        # new finished sequence scores to existing finished scores and select the
        # best from the new set of beams.
        finished_seqs = jnp.concatenate(  # --> [batch, 3*beams, length]
            [state.finished_seqs, topk_seq],
            axis=1)
        finished_scores = jnp.concatenate(  # --> [batch, 3*beams]
            [state.finished_scores, new_scores],
            axis=1)
        finished_flags = jnp.concatenate(  # --> [batch, 3*beams]
            [state.finished_flags, newly_finished],
            axis=1)
        # --> [batch, beams, length], [batch, beams], [batch, beams]
        top_finished_seq, top_finished_scores, top_finished_flags = (
            gather_topk_beams([finished_seqs, finished_scores, finished_flags],
                              finished_scores, batch_size, beam_size))

        return BeamState(cur_index=state.cur_index + 1,
                         live_logprobs=top_alive_log_probs,
                         finished_scores=top_finished_scores,
                         live_seqs=top_alive_seq,
                         finished_seqs=top_finished_seq,
                         finished_flags=top_finished_flags,
                         cache=top_alive_cache)
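A small numeric illustration of the length normalization above. The exact form of brevity_penalty is an assumption here; the GNMT-style penalty ((5 + length) / 6) ** alpha is a common choice in these beam-search implementations.

import jax.numpy as jnp

def brevity_penalty(alpha, length):
    # Assumed GNMT-style length penalty; the real helper may differ.
    return jnp.power((5.0 + length) / 6.0, alpha)

scores = jnp.array([-4.0, -8.0])    # summed log probs of two finished beams
lengths = jnp.array([4.0, 10.0])
print(scores / brevity_penalty(0.6, lengths))  # longer beams are penalized less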
Example no. 21
 def testDynamicSliceGrad(self, shape, dtype, start_indices, size_indices,
                          rng_factory):
   rng = rng_factory(self.rng())
   operand = rng(shape, dtype)
   dynamic_slice = lambda x: lax.dynamic_slice(x, start_indices, size_indices)
   check_grads(dynamic_slice, (operand,), 2, ["fwd", "rev"], eps=1.)
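A quick gradient sanity check in the spirit of the test above: the gradient of a summed dynamic_slice is 1 inside the slice and 0 outside.

import jax
import jax.numpy as jnp
from jax import lax

f = lambda x: lax.dynamic_slice(x, (1,), (2,)).sum()
print(jax.grad(f)(jnp.arange(4.0)))  # [0. 1. 1. 0.]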
Example no. 22
    def __call__(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):

        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        query = self._split_heads(query)
        key = self._split_heads(key)
        value = self._split_heads(value)

        sincos = jnp.take(self.embed_positions, position_ids, axis=0)
        sincos = jnp.split(sincos, 2, axis=-1)
        if self.rotary_dim is not None:
            k_rot = key[:, :, :, :self.rotary_dim]
            k_pass = key[:, :, :, self.rotary_dim:]

            q_rot = query[:, :, :, :self.rotary_dim]
            q_pass = query[:, :, :, self.rotary_dim:]

            k_rot = apply_rotary_pos_emb(k_rot, sincos)
            q_rot = apply_rotary_pos_emb(q_rot, sincos)

            key = jnp.concatenate([k_rot, k_pass], axis=-1)
            query = jnp.concatenate([q_rot, q_pass], axis=-1)
        else:
            key = apply_rotary_pos_emb(key, sincos)
            query = apply_rotary_pos_emb(query, sincos)

        query_length, key_length = query.shape[1], key.shape[1]

        if self.has_variable("cache", "cached_key"):
            mask_shift = self.variables["cache"]["cache_index"]
            max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
            causal_mask = lax.dynamic_slice(
                self.causal_mask, (0, 0, mask_shift, 0),
                (1, 1, query_length, max_decoder_length))
        else:
            causal_mask = self.causal_mask[:, :, :query_length, :key_length]

        batch_size = hidden_states.shape[0]
        causal_mask = jnp.broadcast_to(causal_mask,
                                       (batch_size, ) + causal_mask.shape[1:])

        attention_mask = jnp.broadcast_to(
            jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
        attention_mask = combine_masks(attention_mask, causal_mask)

        dropout_rng = None
        if not deterministic and self.config.attn_pdrop > 0.0:
            dropout_rng = self.make_rng("dropout")

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.has_variable("cache", "cached_key") or init_cache:
            key, value, attention_mask = self._concatenate_to_cache(
                key, value, query, attention_mask)

        # transform boolean mask into float mask
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
        )

        # usual dot product attention
        attn_weights = dot_product_attention_weights(
            query,
            key,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attn_pdrop,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)
        attn_output = self.resid_dropout(attn_output,
                                         deterministic=deterministic)

        outputs = (attn_output,
                   attn_weights) if output_attentions else (attn_output, )
        return outputs
Example no. 23
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        params: dict = None,
        past_key_values: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        batch_size, sequence_length = input_ids.shape

        if position_ids is None:
            if past_key_values is not None and input_ids.shape[-1] == 1:
                # If `past_key_values` are passed and only the last input id is
                # provided, we are in cached auto-regressive generation, so the
                # position_ids have to point at the current cache position.
                cache_shift = flatten_dict(
                    unfreeze(past_key_values))[self._attn_layer_name +
                                               ("cache_index", )]
                position_ids = jnp.broadcast_to(
                    jnp.arange(self.config.max_position_embeddings)[None, :],
                    (batch_size, self.config.max_position_embeddings),
                )
                position_ids = lax.dynamic_slice(position_ids,
                                                 (0, cache_shift),
                                                 (batch_size, 1))
            else:
                position_ids = jnp.broadcast_to(
                    jnp.arange(sequence_length)[None, :],
                    (batch_size, sequence_length))

        if attention_mask is None:
            # if past_key_values are passed we need to create an attention_mask of the same length as `cache_length`
            if past_key_values is not None:
                cache_length = flatten_dict(
                    unfreeze(past_key_values))[self._attn_layer_name +
                                               ("cached_key", )].shape[1]
            else:
                cache_length = sequence_length

            # Note that usually one would have to put 0's in the attention_mask
            # for x > input_ids.shape[-1] and x < cache_length. But since GPT2
            # uses a causal mask, those positions are masked anyway. Thus we can
            # create a single static attention_mask here, which is more
            # efficient for compilation.
            attention_mask = jnp.ones((batch_size, cache_length))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # If past_key_values are passed, the cache is already initialized and a
        # private init_cache flag has to be passed down to ensure the cache is
        # used. The cache also has to be marked as mutable so that it can be
        # updated by the FlaxGPT2Attention module.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        outputs = self.module.apply(
            inputs,
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,
            False,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
            mutable=mutable,
        )

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs, past_key_values = outputs
            outputs["past_key_values"] = unfreeze(past_key_values["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past_key_values = outputs
            outputs = outputs[:1] + (unfreeze(
                past_key_values["cache"]), ) + outputs[1:]

        return outputs
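A minimal sketch of the cached-generation position_ids trick above (sizes are hypothetical): broadcast the full position range, then cut a (batch_size, 1) column at the current cache index with lax.dynamic_slice.

import jax.numpy as jnp
from jax import lax

batch_size, max_position_embeddings = 2, 8
position_ids = jnp.broadcast_to(
    jnp.arange(max_position_embeddings)[None, :],
    (batch_size, max_position_embeddings))
cache_shift = jnp.int32(5)
print(lax.dynamic_slice(position_ids, (0, cache_shift), (batch_size, 1)))
# [[5]
#  [5]]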