Example 1
    def partition(arr, low, high):
        """
        Lomuto partition function.
        """
        if len(arr.shape) > 1:
            raise ValueError("Partition works on 1D arrays. Use vmap.")
        pivot = arr[high]
        def body(state):
            (j, i, arr) = state
            do_swap = arr[j] <= pivot
            i = jnp.where(do_swap, i+1, i)
            ai = arr[i, None]
            aj = arr[j, None]
            arr_swapped = dynamic_update_slice(arr, aj, [i])
            arr_swapped = dynamic_update_slice(arr_swapped, ai, [j])
            arr_swapped = jnp.where(do_swap, arr_swapped, arr)
            return (j + 1, i, arr_swapped)

        (j, i, arr) = while_loop(lambda state: state[0] < high,
                              body,(low, low - 1, arr))

        ai = arr[i+1, None]
        aj = arr[high, None]
        arr_swapped = dynamic_update_slice(arr, aj, [i+1])
        arr_swapped = dynamic_update_slice(arr_swapped, ai, [high])
        return (i + 1, arr_swapped)
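A minimal usage sketch (an addition, not part of the original source): partition assumes jnp, while_loop and dynamic_update_slice are in scope from JAX, and, as its error message suggests, batching is done with jax.vmap:

    import jax
    import jax.numpy as jnp
    from jax.lax import dynamic_update_slice, while_loop  # used by partition above

    batch = jnp.array([[3., 1., 2.],
                       [9., 7., 8.]])
    # Partition each row around its last element; low=0 and high=n-1 are static here.
    pivot_pos, partitioned = jax.vmap(lambda a: partition(a, 0, a.shape[0] - 1))(batch)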
Example 2
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to cached
        states from previous steps. This function is slightly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
        """
        # detect if we're initializing by absence of existing cache data.
        is_initialized = self.has_variable("cache", "cached_key")
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape,
                                   key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros,
                                     value.shape, value.dtype)
        cache_index = self.variable("cache", "cache_index",
                                    lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # update key, value caches with our new 1d spatial slices
            cur_index = cache_index.value
            indices = (0, ) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value,
                                             indices)
            cached_key.value = key
            cached_value.value = value
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # causal mask for cached decoder self-attention: our single query
            # position should only attend to those key positions that have
            # already been generated and cached, not the remaining zero elements.
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            attention_mask = combine_masks(pad_mask, attention_mask)
        return key, value, attention_mask
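The cache write at the heart of this function can be exercised in isolation; a minimal sketch (shapes and names are assumptions for illustration):

    import jax.numpy as jnp
    from jax import lax

    batch, max_length, num_heads, depth = 2, 8, 4, 16
    cached_key = jnp.zeros((batch, max_length, num_heads, depth))
    new_key = jnp.ones((batch, 1, num_heads, depth))  # states for one new position
    cur_index = 3
    # Write the single-position slice into the preallocated buffer at cur_index.
    cached_key = lax.dynamic_update_slice(cached_key, new_key, (0, cur_index, 0, 0))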
Example 3
 def _replace_result(operand, update1, update2):
     operand = dynamic_update_slice(
         operand, update1,
         jnp.asarray([i_lowest] + [0] * (len(operand.shape) - 1)))
     operand = dynamic_update_slice(
         operand, update2,
         jnp.asarray([splitting + 1] + [0] * (len(operand.shape) - 1)))
     return operand
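i_lowest and splitting are free variables captured from the enclosing scope; a self-contained sketch with assumed values shows the effect of the two row-block writes:

    import jax.numpy as jnp
    from jax.lax import dynamic_update_slice

    i_lowest, splitting = 0, 1        # assumed values, purely for illustration
    operand = jnp.zeros((4, 3))
    update1 = jnp.ones((1, 3))        # written starting at row i_lowest
    update2 = 2 * jnp.ones((2, 3))    # written starting at row splitting + 1
    result = _replace_result(operand, update1, update2)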
Example 4
 def body(state):
     (j, i, arr) = state
     do_swap = arr[j] <= pivot
     i = jnp.where(do_swap, i+1, i)
     ai = arr[i, None]
     aj = arr[j, None]
     arr_swapped = dynamic_update_slice(arr, aj, [i])
     arr_swapped = dynamic_update_slice(arr_swapped, ai, [j])
     arr_swapped = jnp.where(do_swap, arr_swapped, arr)
     return (j + 1, i, arr_swapped)
Example 5
  def testDynamicUpdateSliceGrad(self, shape, dtype, start_indices, update_shape):
    rng = jtu.rand_default(self.rng())
    operand = rng(shape, dtype)
    update = rng(update_shape, dtype)
    start_indices = np.array(start_indices)

    dus = lambda x, y: lax.dynamic_update_slice(x, y, start_indices)
    check_grads(dus, (operand, update), 2, ["fwd", "rev"], eps=1.)

    dus = lambda x: lax.dynamic_update_slice(x, update, start_indices)
    check_grads(dus, (operand,), 2, ["fwd", "rev"], eps=1.)

    dus = lambda y: lax.dynamic_update_slice(operand, y, start_indices)
    check_grads(dus, (update,), 2, ["fwd", "rev"], eps=1.)
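A quick reminder of the semantics under test (a sketch, not part of the test itself): dynamic_update_slice clips out-of-range start indices so the update always fits inside the operand, and the gradient checks above cover both the operand and the update:

    import jax.numpy as jnp
    from jax import lax

    x = jnp.zeros(5)
    u = jnp.ones(3)
    # Start index 4 would overrun, so it is clipped to 2: [0. 0. 1. 1. 1.]
    print(lax.dynamic_update_slice(x, u, (4,)))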
Example 6
 def sampling_loop_body_fn(state):
   """Sampling loop state update."""
   i, sequences, cache, cur_token, ended, rng = state
   # Split RNG for sampling.
   rng1, rng2 = random.split(rng)
   # Call fast-decoder model on current tokens to get next-position logits.
   logits, new_cache = tokens_to_logits(cur_token, cache)
   # Sample next token from logits.
   # TODO(levskaya): add top-p "nucleus" sampling option.
   if topk:
     # Get top-k logits and their indices, sample within these top-k tokens.
     topk_logits, topk_idxs = lax.top_k(logits, topk)
     topk_token = jnp.expand_dims(
         random.categorical(rng1, topk_logits / temperature).astype(jnp.int32),
         axis=-1)
     # Return the original indices corresponding to the sampled top-k tokens.
     next_token = jnp.squeeze(
         jnp.take_along_axis(topk_idxs, topk_token, axis=-1), axis=-1)
   else:
     next_token = random.categorical(rng1,
                                     logits / temperature).astype(jnp.int32)
   # Only use sampled tokens if we're past provided prefix tokens.
   out_of_prompt = (sequences[:, i + 1] == 0)
   next_token = (
       next_token * out_of_prompt + sequences[:, i + 1] * ~out_of_prompt)
   # If end-marker reached for batch item, only emit padding tokens.
   next_token_or_endpad = next_token * ~ended
   ended |= (next_token_or_endpad == end_marker)
   # Add current sampled tokens to recorded sequences.
   new_sequences = lax.dynamic_update_slice(sequences, next_token_or_endpad,
                                            (0, i + 1))
   return (i + 1, new_sequences, new_cache, next_token_or_endpad, ended, rng2)
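This body is meant to be driven by lax.while_loop; a minimal driver sketch in which the cond function, max_decode_len and init_state are assumptions based on the state layout above:

    import jax.numpy as jnp
    from jax import lax

    def sampling_loop_cond_fn(state):
        """Stop once the buffer is full or every sequence has ended."""
        # max_decode_len is assumed to be closed over by the enclosing sampler.
        i, _, _, _, ended, _ = state
        return (i < max_decode_len - 1) & ~jnp.all(ended)

    # init_state is assumed to be (0, sequences, cache, token0, ended0, rng).
    final_state = lax.while_loop(sampling_loop_cond_fn,
                                 sampling_loop_body_fn, init_state)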
Example 7
        def greedy_search_body_fn(state):
            """state update fn."""
            model_outputs = model(state.running_token,
                                  params=params,
                                  **state.model_kwargs)
            logits = model_outputs.logits[:, -1]

            # apply min_length, ...
            logits = logits_processor(state.sequences, logits, state.cur_len)

            next_token = jnp.argmax(logits, axis=-1)

            next_token = next_token * ~state.is_sent_finished + pad_token_id * state.is_sent_finished
            next_is_sent_finished = state.is_sent_finished | (next_token
                                                              == eos_token_id)
            next_token = next_token[:, None]

            next_sequences = lax.dynamic_update_slice(state.sequences,
                                                      next_token,
                                                      (0, state.cur_len))
            next_model_kwargs = self.update_inputs_for_generation(
                model_outputs, state.model_kwargs)
            return GreedyState(
                cur_len=state.cur_len + 1,
                sequences=next_sequences,
                running_token=next_token,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
            )
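The matching termination condition is not shown in this snippet; a sketch along the lines of the sampling variant in Example 20 (an assumption, not the original code):

    def greedy_search_cond_fn(state):
        """state termination condition fn."""
        # max_length is assumed to be closed over, as elsewhere in this file.
        has_reached_max_length = state.cur_len == max_length
        all_sequence_finished = jnp.all(state.is_sent_finished)
        return ~jnp.logical_or(has_reached_max_length, all_sequence_finished)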
Example 8
 def body(state: State):
     # N, K
     cluster_dist = vmap(lambda point: cluster_dist_metric(
         point, state.metric_state))(points)
     # N
     current_cluster_dist = vmap(lambda n, k: cluster_dist[n, k])(
         jnp.arange(N, dtype=jnp.int_), state.cluster_id)
     # N, K
     rel_dist = cluster_dist - current_cluster_dist[:, None]
     # N
     min_dist = jnp.min(rel_dist, axis=-1)
     proposed_cluster_id = jnp.argmin(rel_dist, axis=-1)
     can_take_from = state.metric_state.num_k[state.cluster_id] > D + 1
     min_dist = jnp.where(mask & can_take_from, min_dist, jnp.inf)
     amin = jnp.argmin(min_dist)
     k_to = proposed_cluster_id[amin]
     cluster_id = dynamic_update_slice(state.cluster_id, k_to[None],
                                       amin[None])
     # # update cluster_id
     # cluster_id = jnp.where(state.metric_state.num_k[state.cluster_id] < D+1, state.cluster_id, jnp.argmin(rel_dist, axis=-1))
     # proposed_num_k = jnp.bincount(proposed_cluster_id, weights, minlength=0, length=K)
     # cluster_id = jnp.where(proposed_num_k[proposed_cluster_id] < D + 1, state.cluster_id, proposed_cluster_id)
     metric_state = update_metric_state(cluster_id)
     # print()
     # print(state.i, jnp.sum(state.cluster_id!=cluster_id), amin, state.cluster_id[amin], k_to, jnp.min(rel_dist))
     done = jnp.all(cluster_id == state.cluster_id)
     state = state._replace(i=state.i + 1,
                            done=done,
                            cluster_id=cluster_id,
                            metric_state=metric_state)
     return state
Example 9
        def sample_search_body_fn(state):
            """state update fn."""
            prng_key, prng_key_next = jax.random.split(state.prng_key)
            model_outputs = model(state.running_token,
                                  params=params,
                                  **state.model_kwargs)

            logits = model_outputs.logits[:, -1]

            # apply min_length, ...
            logits = logits_processor(state.sequences, logits, state.cur_len)
            # apply top_p, top_k, temperature
            logits = logits_warper(logits, logits, state.cur_len)

            next_token = jax.random.categorical(prng_key, logits, axis=-1)

            next_is_sent_finished = state.is_sent_finished | (next_token
                                                              == eos_token_id)
            next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
            next_token = next_token[:, None]

            next_sequences = lax.dynamic_update_slice(state.sequences,
                                                      next_token,
                                                      (0, state.cur_len))
            next_model_kwargs = self.update_inputs_for_generation(
                model_outputs, state.model_kwargs)

            return SampleState(
                cur_len=state.cur_len + 1,
                sequences=next_sequences,
                running_token=next_token,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
                prng_key=prng_key_next,
            )
Example 10
    def prepare_inputs_for_generation(
            self,
            input_ids,
            max_length,
            attention_mask: Optional[jnp.DeviceArray] = None):
        # initializing the cache
        batch_size, seq_length = input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length)
        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
        # But since GPTJ uses a causal mask, those positions are masked anyway.
        # Thus we can create a single static attention_mask here, which is more efficient for compilation.
        extended_attention_mask = jnp.ones((batch_size, max_length),
                                           dtype="i4")
        if attention_mask is not None:
            position_ids = attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(
                extended_attention_mask, attention_mask, (0, 0))
        else:
            position_ids = jnp.broadcast_to(
                jnp.arange(seq_length, dtype="i4")[None, :],
                (batch_size, seq_length))

        return {
            "past_key_values": past_key_values,
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
        }
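The cumsum trick keeps position ids aligned with the real tokens under left padding; a small concrete check (values assumed for illustration):

    import jax.numpy as jnp

    attention_mask = jnp.array([[0, 0, 1, 1, 1]])       # one left-padded row
    position_ids = attention_mask.cumsum(axis=-1) - 1   # [[-1, -1, 0, 1, 2]]
    # Padding positions get -1; the real tokens get positions 0, 1, 2.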
Example 11
 def loop_trans(j, coords):
     i = (n - j) - 1
     transformed_coords = extend(Triplet(*[di[i] for di in tris]), coords,
                                 False)
     return dynamic_update_slice(transformed_coords, coords_pretrans[i],
                                 [fs * i] + [0] *
                                 (transformed_coords.ndim - 1))
Example 12
 def body(state):
     (i, done, old_cluster_id, _, _, _, _, _, _, _, _, min_loss,
      delay) = state
     mask1 = mask & (old_cluster_id == 0)
     mask2 = mask & (old_cluster_id == 1)
     # estimate volumes of current clustering
     n1 = jnp.sum(mask1)
     n2 = jnp.sum(mask2)
     log_VS1 = log_VS + jnp.log(n1) - jnp.log(n_S)
     log_VS2 = log_VS + jnp.log(n2) - jnp.log(n_S)
     # construct E_1, E_2 and compute volumes
     mu1, C1 = bounding_ellipsoid(points, mask1)
     radii1, rotation1 = ellipsoid_params(C1)
     log_VE1 = log_ellipsoid_volume(radii1)
     mu2, C2 = bounding_ellipsoid(points, mask2)
     radii2, rotation2 = ellipsoid_params(C2)
     log_VE2 = log_ellipsoid_volume(radii2)
     # enlarge to at least cover V(S1) and V(S2)
     log_scale1 = log_coverage_scale(log_VE1, log_VS1, D)
     log_scale2 = log_coverage_scale(log_VE2, log_VS2, D)
     C1 = C1 / jnp.exp(log_scale1)
     radii1 = jnp.exp(jnp.log(radii1) + log_scale1)
     C2 = C2 / jnp.exp(log_scale2)
     radii2 = jnp.exp(jnp.log(radii2) + log_scale2)
     log_VE1 = log_VE1 + log_scale1 * D
     log_VE2 = log_VE2 + log_scale2 * D
     # compute reassignment metrics
     maha1 = vmap(lambda point: (point - mu1) @ C1 @ (point - mu1))(points)
     maha2 = vmap(lambda point: (point - mu2) @ C2 @ (point - mu2))(points)
     log_h1 = log_VE1 - log_VS1 + jnp.log(maha1)
     log_h2 = log_VE2 - log_VS2 + jnp.log(maha2)
     # reassign
     delta_F = jnp.exp(log_h1) - jnp.exp(log_h2)
     reassign_idx = jnp.argmax(jnp.abs(delta_F))
     new_cluster_id = dynamic_update_slice(
         cluster_id, (delta_F[reassign_idx, None] > 0).astype(jnp.int_),
         reassign_idx[None])
     # new_cluster_id = jnp.where(log_h1 < log_h2, 0, 1)
     log_V_sum = jnp.logaddexp(log_VE1, log_VE2)
     new_loss = jnp.exp(log_V_sum - log_VS)
     loss_decreased = new_loss < min_loss
     delay = jnp.where(loss_decreased, 0, delay + 1)
     min_loss = jnp.where(loss_decreased, new_loss, min_loss)
     ###
     # i / delay / loss_decreased / new_loss / min_loss
     # 0 / 0 / True / a / a
     # 1 / 1 / False / b / a
     # 2 / 2 / False / a / a
     # 3 / 3 / False / b / a
     # 4 / 4 / False / a / a
     done = jnp.all(new_cluster_id == old_cluster_id) \
            | (delay >= 10) \
            | (n1 < D + 1) \
            | (n2 < D + 1) \
            | jnp.isnan(log_V_sum)
     # print(i, "reassignments", jnp.sum(new_cluster_id != old_cluster_id), 'F', log_V_sum)
     # print(i, done, jnp.abs(delta_F).max())
     return (i + 1, done, new_cluster_id, log_VS1, mu1, radii1, rotation1,
             log_VS2, mu2, radii2, rotation2, min_loss, delay)
Example 13
        def generate(prompt_ids):
            def first_pass(prompt_ids):
                logits, cache = model(prompt_ids,
                                      past_key_values=past_key_values)[:2]
                next_token = jnp.argmax(logits[:, -1:], axis=-1)
                return next_token, cache

            def greedy_search_cond_fn(state):
                cur_len, _, _, _ = state
                return ~(cur_len == max_length - 1)

            def greedy_search_body_fn(state):
                cur_len, sequences, current_token, cache = state
                next_sequences = lax.dynamic_update_slice(
                    sequences, current_token, (0, cur_len))

                next_logits, next_cache = model(current_token,
                                                past_key_values=cache)[:2]
                next_token = jnp.argmax(next_logits, axis=-1)

                return cur_len + 1, next_sequences, next_token, next_cache

            # init tensor to be filled with generation result
            init_sequences = jnp.zeros((batch_size, max_length), dtype="i4")
            init_sequences = lax.dynamic_update_slice(init_sequences,
                                                      prompt_ids, (0, 0))

            # init past key values for cache
            past_key_values = model.init_cache(batch_size, max_length)

            # first pass with long prompt
            next_token, cache = first_pass(prompt_ids)

            # prepare state for generation loop
            init_state = (jnp.array(prompt_length), init_sequences, next_token,
                          cache)

            # fast generation
            _, output_sequences, final_token, _ = lax.while_loop(
                greedy_search_cond_fn, greedy_search_body_fn, init_state)

            # append last token
            output_sequences = lax.dynamic_update_slice(
                output_sequences, final_token, (0, max_length - 1))

            return output_sequences
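Since the loop is built entirely from lax control flow, the whole function can be staged out with jax.jit; a usage sketch, assuming batch_size, max_length and prompt_length are fixed in the enclosing scope:

    import jax

    generate_jit = jax.jit(generate)
    output_sequences = generate_jit(prompt_ids)  # compiled once per prompt shape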
Example 14
    def sampling_loop_body_fn(state):
        """Sampling loop state update."""
        i, sequences, cache, cur_token, ended, rng, tokens_to_logits_state = state

        # Split RNG for sampling.
        rng1, rng2 = random.split(rng)

        # Call fast-decoder model on current tokens to get raw next-position logits.
        logits, new_cache, new_tokens_to_logits_state = tokens_to_logits(
            cur_token, cache, internal_state=tokens_to_logits_state)
        logits = logits / temperature

        # Mask out the BOS token.
        if masked_tokens is not None:
            mask = common_utils.onehot(jnp.array(masked_tokens),
                                       num_classes=logits.shape[-1],
                                       on_value=LARGE_NEGATIVE)
            mask = jnp.sum(mask,
                           axis=0)[None, :]  # Combine multiple masks together
            logits = logits + mask

        # Apply the repetition penalty.
        if repetition_penalty != 1:
            logits = apply_repetition_penalty(
                sequences,
                logits,
                i,
                repetition_penalty=repetition_penalty,
                repetition_window=repetition_window,
                repetition_penalty_normalize=repetition_penalty_normalize)

        # Mask out everything but the top-k entries.
        if top_k is not None:
            # Compute top_k_index and top_k_threshold with shapes (batch_size, 1).
            top_k_index = jnp.argsort(logits,
                                      axis=-1)[:, ::-1][:, top_k - 1:top_k]
            top_k_threshold = jnp.take_along_axis(logits, top_k_index, axis=-1)
            logits = jnp.where(logits < top_k_threshold,
                               jnp.full_like(logits, LARGE_NEGATIVE), logits)
        # Sample next token from logits.
        sample = multinomial(rng1, logits)
        next_token = sample.astype(jnp.int32)
        # Only use sampled tokens if we are past the out_of_prompt_marker.
        out_of_prompt = (sequences[:, i + 1] == out_of_prompt_marker)
        next_token = (next_token * out_of_prompt +
                      sequences[:, i + 1] * ~out_of_prompt)
        # If end-marker reached for batch item, only emit padding tokens.
        next_token = next_token[:, None]
        next_token_or_endpad = jnp.where(ended,
                                         jnp.full_like(next_token, pad_token),
                                         next_token)
        ended |= (next_token_or_endpad == end_marker)
        # Add current sampled tokens to recorded sequences.
        new_sequences = lax.dynamic_update_slice(sequences,
                                                 next_token_or_endpad,
                                                 (0, i + 1))
        return (i + 1, new_sequences, new_cache, next_token_or_endpad, ended,
                rng2, new_tokens_to_logits_state)
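multinomial is a helper from the surrounding module; a minimal stand-in consistent with its usage above (one integer draw per row of logits) might look like this sketch:

    import jax

    def multinomial(rng, logits):
        # One categorical draw per batch row, matching how sample is used above.
        return jax.random.categorical(rng, logits, axis=-1)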
Example 15
            def greedy_search_body_fn(state):
                cur_len, sequences, current_token, cache = state
                next_sequences = lax.dynamic_update_slice(
                    sequences, current_token, (0, cur_len))

                next_logits, next_cache = model(current_token,
                                                past_key_values=cache)[:2]
                next_token = jnp.argmax(next_logits, axis=-1)

                return cur_len + 1, next_sequences, next_token, next_cache
Example 16
def _matrix_put(ndarray, idx, val, block_size=1):
    """Similar to numpy.put using LAX operations."""
    idx_i, idx_j = idx
    sli, row_rev = _canonical_idx(ndarray.shape, idx_i, -2, block_size)
    slj, col_rev = _canonical_idx(ndarray.shape, idx_j, -1, block_size)
    if not sli.step == slj.step == 1:
        raise TypeError("Non-unit step not supported in assigment.")

    if row_rev or col_rev:
        val = lax.rev(val, *onp.where([row_rev, col_rev]))

    start_indices = [0] * (ndarray.ndim - 2) + [sli.start, slj.start]
    return lax.dynamic_update_slice(ndarray, val, start_indices)
Example 17
def _update_slice(operand, update, start_indices, update_dims):
    """
  Similar to lax.dynamic_update_slice, but handles padded updates where padding
  values should not overwrite existing values in the array.

  Args:
  operand: the array to update
  update: the padded array to write
  start_indices: the offset at which to write `update`.
  update_dims: the true dimensions of the padded update `update`. Only values
    inside the rectangle given by `update_dims` will be overwritten."""
    operand_shape = operand.shape
    operand = lax.pad(operand, jnp.array(0, operand.dtype),
                      [(0, d, 0) for d in update.shape])
    start_indices = tuple(jnp.int32(i) for i in start_indices)
    t = lax.dynamic_slice(operand, start_indices, update.shape)
    t = _mask(update, update_dims, t)
    operand = lax.dynamic_update_slice(operand, t, start_indices)
    return lax.slice(operand, [0] * operand.ndim, operand_shape)
Example 18
 def body(state):
     (i, u, p, num_f, distance, num_bounce, results_array) = state
     # half step forward
     u = u + step_size * p
     # check violation
     C, C_grad = val_and_grad_from_U(u)
     n = C_grad / jnp.linalg.norm(C_grad)
     # bounce off contours and boundaries of the unit cube
     p_bounce = jnp.where(((u > 1.) | (u < 0.)) | jnp.isnan(n), -p,
                          p - 2. * (p @ n) * n)
     # update momentum
     p = jnp.where(C >= 0., p, p_bounce)
     # update counters
     num_bounce = jnp.where(C >= 0., num_bounce, num_bounce + 1)
     distance = distance + step_size
     num_f = num_f + 1
     results_array = dynamic_update_slice(results_array, u[None, :],
                                          [i, 0])
     return (i + 1, u, p, num_f, distance, num_bounce, results_array)
Example 19
        def greedy_search_body_fn(state):
            """state update fn."""
            model_outputs = model(state.current_token, **state.model_kwargs)
            next_token = jnp.argmax(model_outputs.logits[:, -1], axis=-1)

            next_is_sent_finished = state.is_sent_finished | (next_token
                                                              == eos_token_id)
            next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
            next_token = next_token[:, None]

            next_sequences = lax.dynamic_update_slice(state.sequences,
                                                      next_token,
                                                      (0, state.cur_len))
            next_model_kwargs = model.update_inputs_for_generation(
                model_outputs, state.model_kwargs)

            return GreedyState(
                cur_len=state.cur_len + 1,
                sequences=next_sequences,
                current_token=next_token,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
            )
Example 20
    def _sample(
        self,
        input_ids: jnp.ndarray,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        prng_key: Optional[jnp.ndarray] = None,
        logits_processor: Optional[FlaxLogitsProcessorList] = None,
        logits_warper: Optional[FlaxLogitsProcessorList] = None,
        trace: bool = True,
        params: Optional[Dict[str, jnp.ndarray]] = None,
        model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
    ):
        # init values
        max_length = max_length if max_length is not None else self.config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)

        batch_size, cur_len = input_ids.shape

        eos_token_id = jnp.array(eos_token_id)
        pad_token_id = jnp.array(pad_token_id)
        cur_len = jnp.array(cur_len)

        # per batch-item holding current token in loop.
        sequences = jnp.full((batch_size, max_length),
                             pad_token_id,
                             dtype=jnp.int32)
        sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))

        # per batch-item state bit indicating if sentence has finished.
        is_sent_finished = jnp.zeros((batch_size, ), dtype=jnp.bool_)

        # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
        # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
        model = self.decode if self.config.is_encoder_decoder else self

        # initialize model specific kwargs
        model_kwargs = self.prepare_inputs_for_generation(
            input_ids, max_length, **model_kwargs)

        # initialize state
        state = SampleState(
            cur_len=cur_len,
            sequences=sequences,
            running_token=input_ids,
            is_sent_finished=is_sent_finished,
            prng_key=prng_key,
            model_kwargs=model_kwargs,
        )

        def sample_search_cond_fn(state):
            """state termination condition fn."""
            has_reached_max_length = state.cur_len == max_length
            all_sequence_finished = jnp.all(state.is_sent_finished)
            finish_generation = jnp.logical_or(has_reached_max_length,
                                               all_sequence_finished)
            return ~finish_generation

        def sample_search_body_fn(state):
            """state update fn."""
            prng_key, prng_key_next = jax.random.split(state.prng_key)
            model_outputs = model(state.running_token,
                                  params=params,
                                  **state.model_kwargs)

            logits = model_outputs.logits[:, -1]

            # apply min_length, ...
            logits = logits_processor(state.sequences, logits, state.cur_len)
            # apply top_p, top_k, temperature
            logits = logits_warper(logits, logits, state.cur_len)

            next_token = jax.random.categorical(prng_key, logits, axis=-1)

            next_is_sent_finished = state.is_sent_finished | (next_token
                                                              == eos_token_id)
            next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
            next_token = next_token[:, None]

            next_sequences = lax.dynamic_update_slice(state.sequences,
                                                      next_token,
                                                      (0, state.cur_len))
            next_model_kwargs = self.update_inputs_for_generation(
                model_outputs, state.model_kwargs)

            return SampleState(
                cur_len=state.cur_len + 1,
                sequences=next_sequences,
                running_token=next_token,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
                prng_key=prng_key_next,
            )

        # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
        if input_ids.shape[1] > 1:
            state = sample_search_body_fn(state)

        if not trace:
            state = self._run_loop_in_debug(sample_search_cond_fn,
                                            sample_search_body_fn, state)
        else:
            state = lax.while_loop(sample_search_cond_fn,
                                   sample_search_body_fn, state)

        return FlaxSampleOutput(sequences=state.sequences)
Example 21
  def __call__(self,
               inputs_q,
               inputs_kv,
               padding_mask=None,
               key_padding_mask=None,
               segmentation=None,
               key_segmentation=None,
               decode=False):
    """Applies multi-head dot product attention on the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention, and projects the results to an output vector.

    This can be used for encoder-decoder attention by specifying both `inputs_q`
    and `inputs_kv` or for self-attention by only specifying `inputs_q` and
    setting `inputs_kv` to None.

    Args:
      inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
      inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]`
        or None for self-attention, in which case key/values will be derived
        from inputs_q.
      padding_mask: boolean specifying query tokens that are pad token.
      key_padding_mask: boolean specifying key-value tokens that are pad token.
      segmentation: segment indices for packed inputs_q data.
      key_segmentation: segment indices for packed inputs_kv data.

    Returns:
      output of shape `[bs, dim1, dim2, ..., dimN, features]`.
    """

    assert self.causal_mask or not self.decode, (
        'Caching is only supported for causal attention.')

    if inputs_kv is None:
      inputs_kv = inputs_q

    attention_axis = self.attention_axis
    if self.attention_axis is None:
      attention_axis = tuple(range(1, inputs_q.ndim - 1))

    features = self.out_features or inputs_q.shape[-1]
    qkv_features = self.qkv_features or inputs_q.shape[-1]

    assert qkv_features % self.num_heads == 0, (
        'Memory dimension must be divisible by number of heads.')
    head_dim = qkv_features // self.num_heads

    dense = partial(DenseGeneral,
                    axis=-1,
                    features=(self.num_heads, head_dim),
                    kernel_init=self.kernel_init,
                    bias_init=self.bias_init,
                    use_bias=self.use_bias,
                    precision=self.precision)
    # project inputs_q to multi-headed q/k/v
    # dimensions are then [bs, dims..., n_heads, n_features_per_head]
    query, key, value = (dense(dtype=self.dtype, name='query')(inputs_q),
                         dense(dtype=self.dtype, name='key')(inputs_kv),
                         dense(dtype=self.dtype, name='value')(inputs_kv))


    if self.decode:
      # detect if we're initializing by absence of existing cache data.
      is_initialized = self.has_variable('cache', 'cached_key')
      cached_key = self.variable('cache', 'cached_key',
                                 jnp.zeros, key.shape, key.dtype)
      cached_value = self.variable('cache', 'cached_value',
                                   jnp.zeros, value.shape, value.dtype)
      cache_index = self.variable('cache', 'cache_index',
                                  lambda: jnp.array(0, dtype=jnp.uint32))
      if is_initialized:
        expected_shape = list(cached_key.value.shape[:-2])
        for attn_dim in attention_axis:
          expected_shape[attn_dim] = 1
        expected_shape = tuple(expected_shape) + inputs_q.shape[-1:]
        if expected_shape != inputs_q.shape:
          raise ValueError('Invalid shape provided, '
                           'expected shape %s instead got %s.' %
                           (expected_shape, inputs_q.shape))

        cshape = cached_key.value.shape
        indices = [0] * len(cshape)
        i = cache_index.value
        attn_size = np.prod(np.take(cshape, attention_axis))
        for attn_dim in attention_axis:
          attn_size //= cshape[attn_dim]
          indices[attn_dim] = i // attn_size
          i = i % attn_size

        key = lax.dynamic_update_slice(cached_key.value, key, indices)
        value = lax.dynamic_update_slice(cached_value.value, value, indices)
        cached_key.value = key
        cached_value.value = value
        cache_index.value = cache_index.value + 1

        # TODO(levskaya): verify this is still needed in translation decoding.
        key_padding_mask = jnp.broadcast_to(
            (jnp.arange(cshape[1]) < cache_index.value), cshape[:2])
        key_padding_mask = key_padding_mask.astype(jnp.float32)[..., None]

    # create attention masks
    mask_components = []

    if self.causal_mask:
      if self.decode and is_initialized:
        bias_pre_shape = (1,) * (key.ndim - 1)
        attn_shape = tuple(np.take(key.shape, attention_axis))
        attn_size = np.prod(attn_shape)
        ii = jnp.arange(attn_size, dtype=jnp.uint32)
        mask = ii < cache_index.value
        mask_components.append(mask.reshape(bias_pre_shape + attn_shape))
      else:
        mask_components.append(_make_causal_mask(key, attention_axis))

    if padding_mask is not None:
      if key_padding_mask is None:
        key_padding_mask = padding_mask
      padding_mask = make_padding_mask(
          padding_mask_query=padding_mask,
          padding_mask_key=key_padding_mask,
          query_shape=query.shape,
          key_shape=key.shape,
          attention_axis=attention_axis)
      mask_components.append(padding_mask)

    if segmentation is not None:
      if key_segmentation is None:
        key_segmentation = segmentation
      segmentation_mask = make_padding_mask(
          padding_mask_query=segmentation,
          padding_mask_key=key_segmentation,
          query_shape=query.shape,
          key_shape=key.shape,
          attention_axis=attention_axis,
          segmentation_mask=True)
      mask_components.append(segmentation_mask)

    if mask_components:
      attention_mask = mask_components[0]
      for component in mask_components[1:]:
        attention_mask = jnp.logical_and(attention_mask, component)

      # attention mask in the form of attention bias
      attention_bias = lax.select(
          attention_mask > 0, jnp.full(attention_mask.shape, 0.).astype(self.dtype),
          jnp.full(attention_mask.shape, -1e10).astype(self.dtype))
    else:
      attention_bias = None

    dropout_rng = None
    if not self.deterministic and self.dropout_rate > 0.:
      dropout_rng = self.make_rng('dropout')

    # apply attention
    x = self.attention_fn(
        query,
        key,
        value,
        dtype=self.dtype,
        axis=attention_axis,
        bias=attention_bias,
        precision=self.precision,
        dropout_rng=dropout_rng,
        dropout_rate=self.dropout_rate,
        broadcast_dropout=self.broadcast_dropout,
        deterministic=self.deterministic)

    # back to the original inputs dimensions
    out = DenseGeneral(features=features,
                       axis=(-2, -1),
                       kernel_init=self.kernel_init,
                       bias_init=self.bias_init,
                       use_bias=self.use_bias,
                       dtype=self.dtype,
                       precision=self.precision,
                       name='out')(x)

    return out
Example 22
    def beam_search_loop_body_fn(state):
        """Beam search loop state update function."""
        # Collect the current position slice along length to feed the fast
        # autoregressive decoder model.  Flatten the beam dimension into batch
        # dimension for feeding into the model.
        # --> [batch * beam, 1]
        flat_ids = flatten_beam_dim(
            lax.dynamic_slice(state.live_seqs, (0, 0, state.cur_index),
                              (batch_size, beam_size, 1)))
        # Flatten beam dimension into batch to be compatible with model.
        # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...}
        flat_cache = jax.tree_map(flatten_beam_dim, state.cache)

        # Call fast-decoder model on current tokens to get next-position logits.
        # --> [batch * beam, vocab]
        flat_logits, new_flat_cache = tokens_to_logits(flat_ids, flat_cache)

        # unflatten beam dimension
        # [batch * beam, vocab] --> [batch, beam, vocab]
        logits = unflatten_beam_dim(flat_logits, batch_size, beam_size)
        # Unflatten beam dimension in attention cache arrays
        # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...}
        new_cache = jax.tree_map(
            lambda x: unflatten_beam_dim(x, batch_size, beam_size),
            new_flat_cache)

        # Gather log probabilities from logits
        candidate_log_probs = jax.nn.log_softmax(logits)
        # Add new logprobs to existing prefix logprobs.
        # --> [batch, beam, vocab]
        log_probs = (candidate_log_probs +
                     jnp.expand_dims(state.live_logprobs, axis=2))

        # We'll need the vocab size, gather it from the log probability dimension.
        vocab_size = log_probs.shape[2]

        # Each item in batch has beam_size * vocab_size candidate sequences.
        # For each item, get the top 2*k candidates with the highest log-
        # probabilities. We gather the top 2*K beams here so that even if the best
        # K sequences reach EOS simultaneously, we have another K sequences
        # remaining to continue the live beam search.
        beams_to_keep = 2 * beam_size
        # Flatten beam and vocab dimensions.
        flat_log_probs = log_probs.reshape(
            (batch_size, beam_size * vocab_size))
        # Gather the top 2*K scores from _all_ beams.
        # --> [batch, 2*beams], [batch, 2*beams]
        topk_log_probs, topk_indices = lax.top_k(flat_log_probs,
                                                 k=beams_to_keep)
        # Recover the beam index by floor division.
        topk_beam_indices = topk_indices // vocab_size
        # Gather 2*k top beams.
        # --> [batch, 2*beams, length]
        topk_seq = gather_beams(state.live_seqs, topk_beam_indices, batch_size,
                                beams_to_keep)

        # Append the most probable 2*K token IDs to the top 2*K sequences
        # Recover the token ID by modulo division and expand the ID array for broadcasting.
        # --> [batch, 2*beams, 1]
        topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
        # Update sequences for the 2*K top-k new sequences.
        # --> [batch, 2*beams, length]
        topk_seq = lax.dynamic_update_slice(topk_seq, topk_ids,
                                            (0, 0, state.cur_index + 1))

        # Update LIVE (in-progress) sequences:
        # Did any of these sequences reach an end marker?
        # --> [batch, 2*beams]
        newly_finished = (topk_seq[:, :, state.cur_index + 1] == end_marker)
        # To prevent these newly finished sequences from being added to the LIVE
        # set of active beam search sequences, set their log probs to a very large
        # negative value.
        new_log_probs = topk_log_probs + newly_finished * NEG_INF
        # Determine the top k beam indices (from top 2*k beams) from log probs.
        # --> [batch, beams]
        _, new_topk_indices = lax.top_k(new_log_probs, k=beam_size)
        new_topk_indices = jnp.flip(new_topk_indices, axis=1)
        # Gather the top k beams (from top 2*k beams).
        # --> [batch, beams, length], [batch, beams]
        top_alive_seq, top_alive_log_probs = gather_beams(
            [topk_seq, new_log_probs], new_topk_indices, batch_size, beam_size)

        # Determine the top k beam indices from the original set of all beams.
        # --> [batch, beams]
        top_alive_indices = gather_beams(topk_beam_indices, new_topk_indices,
                                         batch_size, beam_size)
        # With these, gather the top k beam-associated caches.
        # --> {[batch, beams, ...], ...}
        top_alive_cache = gather_beams(new_cache, top_alive_indices,
                                       batch_size, beam_size)

        # Update FINISHED (reached end of sentence) sequences:
        # Calculate new seq scores from log probabilities.
        new_scores = topk_log_probs / brevity_penalty(alpha,
                                                      state.cur_index + 1)
        # Mask out the still unfinished sequences by adding large negative value.
        # --> [batch, 2*beams]
        new_scores += (~newly_finished) * NEG_INF

        # Combine sequences, scores, and flags along the beam dimension and compare
        # new finished sequence scores to existing finished scores and select the
        # best from the new set of beams.
        finished_seqs = jnp.concatenate(  # --> [batch, 3*beams, length]
            [state.finished_seqs, topk_seq],
            axis=1)
        finished_scores = jnp.concatenate(  # --> [batch, 3*beams]
            [state.finished_scores, new_scores],
            axis=1)
        finished_flags = jnp.concatenate(  # --> [batch, 3*beams]
            [state.finished_flags, newly_finished],
            axis=1)
        # --> [batch, beams, length], [batch, beams], [batch, beams]
        top_finished_seq, top_finished_scores, top_finished_flags = (
            gather_topk_beams([finished_seqs, finished_scores, finished_flags],
                              finished_scores, batch_size, beam_size))

        return BeamState(cur_index=state.cur_index + 1,
                         live_logprobs=top_alive_log_probs,
                         finished_scores=top_finished_scores,
                         live_seqs=top_alive_seq,
                         finished_seqs=top_finished_seq,
                         finished_flags=top_finished_flags,
                         cache=top_alive_cache)
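brevity_penalty is defined elsewhere in the module; the conventional GNMT-style form used by the Flax WMT example is sketched below (the exact expression in this source is assumed):

    def brevity_penalty(alpha, length):
        """GNMT-style length penalty: longer finished sequences score higher."""
        return ((5.0 + length) / 6.0) ** alpha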
Example 23
 def scalar_f2(x):
   return lax.dynamic_update_slice(x, 7, [])
Example 24
 def update_entry(arr, val, i, j):
     val = lax.reshape(val, [1, 1])
     return lax.dynamic_update_slice(arr, val, (i, j))
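A usage sketch (assumed values): this is the standard jit-friendly way to set a single matrix entry, since in-place assignment like arr[i, j] = val is not allowed on traced arrays. It assumes lax is in scope, as the helper above requires:

    import jax.numpy as jnp
    from jax import lax

    arr = jnp.zeros((3, 3))
    arr = update_entry(arr, jnp.float32(7.0), 1, 2)  # sets arr[1, 2] to 7.0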
Example 25
        def beam_search_body_fn(state, input_ids_length=1):
            """beam search state update fn."""
            # 1. Forward current tokens
            # Collect the current position slice along length to feed the fast
            # autoregressive decoder model.  Flatten the beam dimension into batch
            # dimension for feeding into the model.
            # unflatten beam dimension
            # Unflatten beam dimension in attention cache arrays
            input_token = flatten_beam_dim(
                lax.dynamic_slice(
                    state.running_sequences,
                    (0, 0, state.cur_len - input_ids_length),
                    (batch_size, num_beams, input_ids_length),
                ))
            model_outputs = model(input_token,
                                  params=params,
                                  **state.model_kwargs)

            logits = unflatten_beam_dim(model_outputs.logits[:, -1],
                                        batch_size, num_beams)
            cache = jax.tree_map(
                lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams
                                                  ),
                model_outputs.past_key_values)

            # adapt logits for FlaxMarianMTModel
            logits = self._adapt_logits_for_beam_search(logits)

            # 2. Compute log probs
            # get log probabilities from logits,
            # process logits with processors (*e.g.* min_length, ...), and
            # add new logprobs to existing running logprobs scores.
            log_probs = jax.nn.log_softmax(logits)
            log_probs = logits_processor(flatten_beam_dim(running_sequences),
                                         flatten_beam_dim(log_probs),
                                         state.cur_len)
            log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
            log_probs = log_probs + jnp.expand_dims(state.running_scores,
                                                    axis=2)
            vocab_size = log_probs.shape[2]
            log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))

            # 3. Retrieve top-K
            # Each item in batch has num_beams * vocab_size candidate sequences.
            # For each item, get the top 2*k candidates with the highest log-
            # probabilities. We gather the top 2*K beams here so that even if the best
            # K sequences reach EOS simultaneously, we have another K sequences
            # remaining to continue the live beam search.
            # Gather the top 2*K scores from _all_ beams.
            # Gather 2*k top beams.
            # Recover the beam index by floor division.
            # Recover the token ID by modulo division and expand the ID array for broadcasting.
            # Update sequences for the 2*K top-k new sequences.
            beams_to_keep = 2 * num_beams
            topk_log_probs, topk_indices = lax.top_k(log_probs,
                                                     k=beams_to_keep)
            topk_beam_indices = topk_indices // vocab_size
            topk_running_sequences = gather_beams(state.running_sequences,
                                                  topk_beam_indices,
                                                  batch_size, beams_to_keep)
            topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
            topk_sequences = lax.dynamic_update_slice(topk_running_sequences,
                                                      topk_ids,
                                                      (0, 0, state.cur_len))

            # 4. Check which sequences have ended
            # Update current sequences:
            # Did any of these sequences reach an end marker?
            # To prevent these just finished sequences from being added to the current sequences
            # set of active beam search sequences, set their log probs to a very large
            # negative value.
            did_topk_just_finished = (topk_sequences[:, :, state.cur_len] ==
                                      eos_token_id)
            running_topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(
                -1.0e7)
            # 5. Get running sequences scores for next
            # Determine the top k beam indices (from top 2*k beams) from log probs
            # and gather top k beams (from top 2*k beams).
            next_topk_indices = jnp.flip(lax.top_k(running_topk_log_probs,
                                                   k=num_beams)[1],
                                         axis=1)
            next_running_sequences, next_running_scores = gather_beams(
                [topk_sequences, running_topk_log_probs], next_topk_indices,
                batch_size, num_beams)

            # 6. Process topk logits
            # Further process log probs:
            # - add length penalty
            # - make sure no scores can be added anymore if beam is full
            # - make sure still running sequences cannot be chosen as finalized beam
            topk_log_probs = topk_log_probs / (state.cur_len**length_penalty)
            beams_in_batch_are_full = (jnp.broadcast_to(
                state.is_sent_finished.all(axis=-1, keepdims=True),
                did_topk_just_finished.shape)
                                       & early_stopping)
            add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
            topk_log_probs += add_penalty * np.array(-1.0e7)

            # 7. Get scores, sequences, is sentence finished for next.
            # Combine sequences, scores, and flags along the beam dimension and compare
            # new finished sequence scores to existing finished scores and select the
            # best from the new set of beams
            merged_sequences = jnp.concatenate(
                [state.sequences, topk_sequences], axis=1)
            merged_scores = jnp.concatenate([state.scores, topk_log_probs],
                                            axis=1)
            merged_is_sent_finished = jnp.concatenate(
                [state.is_sent_finished, did_topk_just_finished], axis=1)
            topk_merged_indices = jnp.flip(lax.top_k(merged_scores,
                                                     k=num_beams)[1],
                                           axis=1)
            next_sequences, next_scores, next_is_sent_finished = gather_beams(
                [merged_sequences, merged_scores, merged_is_sent_finished],
                topk_merged_indices, batch_size, num_beams)

            # 8. Update model kwargs.
            # Determine the top k beam indices from the original set of all beams.
            # With these, gather the top k beam-associated caches.
            next_running_indices = gather_beams(topk_beam_indices,
                                                next_topk_indices, batch_size,
                                                num_beams)
            next_cache = gather_beams(cache, next_running_indices, batch_size,
                                      num_beams)
            model_outputs["past_key_values"] = jax.tree_map(
                lambda x: flatten_beam_dim(x), next_cache)
            next_model_kwargs = self.update_inputs_for_generation(
                model_outputs, state.model_kwargs)

            return BeamSearchState(
                cur_len=state.cur_len + 1,
                running_scores=next_running_scores,
                running_sequences=next_running_sequences,
                scores=next_scores,
                sequences=next_sequences,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
            )
Example 26
    def _beam_search(
        self,
        input_ids: jnp.ndarray,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        early_stopping: Optional[bool] = None,
        logits_processor: Optional[FlaxLogitsProcessorList] = None,
        trace: bool = True,
        params: Optional[Dict[str, jnp.ndarray]] = None,
        model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
    ):
        """
        This beam search function is heavily inspired by Flax's official example:
        https://github.com/google/flax/blob/master/examples/wmt/train.py#L254
        """
        def flatten_beam_dim(tensor):
            """Flattens the first two dimensions of a non-scalar array."""
            # ignore scalars (e.g. cache index)
            if tensor.ndim == 0:
                return tensor
            return tensor.reshape((tensor.shape[0] * tensor.shape[1], ) +
                                  tensor.shape[2:])

        def unflatten_beam_dim(tensor, batch_size, num_beams):
            """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
            # ignore scalars (e.g. cache index)
            if tensor.ndim == 0:
                return tensor
            return tensor.reshape((batch_size, num_beams) + tensor.shape[1:])

        def gather_beams(nested, beam_indices, batch_size, new_num_beams):
            """
            Gathers the beam slices indexed by beam_indices into new beam array.
            """
            batch_indices = jnp.reshape(
                jnp.arange(batch_size * new_num_beams) // new_num_beams,
                (batch_size, new_num_beams))

            def gather_fn(tensor):
                # ignore scalars (e.g. cache index)
                if tensor.ndim == 0:
                    return tensor
                else:
                    return tensor[batch_indices, beam_indices]

            return jax.tree_map(gather_fn, nested)

        # init values
        max_length = max_length if max_length is not None else self.config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping

        batch_size, num_beams, cur_len = input_ids.shape

        eos_token_id = jnp.array(eos_token_id)
        pad_token_id = jnp.array(pad_token_id)
        cur_len = jnp.array(cur_len)

        # per batch,beam-item holding current token in loop.
        sequences = jnp.full((batch_size, num_beams, max_length),
                             pad_token_id,
                             dtype=jnp.int32)
        running_sequences = jnp.full((batch_size, num_beams, max_length),
                                     pad_token_id,
                                     dtype=jnp.int32)
        running_sequences = lax.dynamic_update_slice(sequences, input_ids,
                                                     (0, 0, 0))

        # per batch,beam-item state bit indicating if sentence has finished.
        is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_)

        # per batch,beam-item score, logprobs
        running_scores = jnp.tile(
            jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)),
            [batch_size, 1])
        scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7)

        # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
        # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
        model = self.decode if self.config.is_encoder_decoder else self

        # flatten beam dim
        if "encoder_outputs" in model_kwargs:
            model_kwargs["encoder_outputs"][
                "last_hidden_state"] = flatten_beam_dim(
                    model_kwargs["encoder_outputs"]["last_hidden_state"])
        if "attention_mask" in model_kwargs:
            model_kwargs["attention_mask"] = flatten_beam_dim(
                model_kwargs["attention_mask"])

        # initialize model specific kwargs
        model_kwargs = self.prepare_inputs_for_generation(
            flatten_beam_dim(input_ids), max_length, **model_kwargs)

        # initialize state
        state = BeamSearchState(
            cur_len=cur_len,
            running_sequences=running_sequences,
            running_scores=running_scores,
            sequences=sequences,
            scores=scores,
            is_sent_finished=is_sent_finished,
            model_kwargs=model_kwargs,
        )

        def beam_search_cond_fn(state):
            """beam search state termination condition fn."""

            # 1. is less than max length?
            not_max_length_yet = state.cur_len < max_length

            # 2. can the new beams still improve?
            best_running_score = state.running_scores[:, -1:] / (
                max_length**length_penalty)
            worst_finished_score = jnp.where(
                state.is_sent_finished,
                jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7))
            improvement_still_possible = jnp.all(
                worst_finished_score < best_running_score)

            # 3. is there still a beam that has not finished?
            still_open_beam = ~(jnp.all(state.is_sent_finished)
                                & early_stopping)

            return not_max_length_yet & still_open_beam & improvement_still_possible

        def beam_search_body_fn(state, input_ids_length=1):
            """beam search state update fn."""
            # 1. Forward current tokens
            # Collect the current position slice along length to feed the fast
            # autoregressive decoder model.  Flatten the beam dimension into batch
            # dimension for feeding into the model.
            # unflatten beam dimension
            # Unflatten beam dimension in attention cache arrays
            input_token = flatten_beam_dim(
                lax.dynamic_slice(
                    state.running_sequences,
                    (0, 0, state.cur_len - input_ids_length),
                    (batch_size, num_beams, input_ids_length),
                ))
            model_outputs = model(input_token,
                                  params=params,
                                  **state.model_kwargs)

            logits = unflatten_beam_dim(model_outputs.logits[:, -1],
                                        batch_size, num_beams)
            cache = jax.tree_map(
                lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams
                                                  ),
                model_outputs.past_key_values)

            # adapt logits for FlaxMarianMTModel
            logits = self._adapt_logits_for_beam_search(logits)

            # 2. Compute log probs
            # get log probabilities from logits,
            # process logits with processors (*e.g.* min_length, ...), and
            # add new logprobs to existing running logprobs scores.
            log_probs = jax.nn.log_softmax(logits)
            log_probs = logits_processor(flatten_beam_dim(running_sequences),
                                         flatten_beam_dim(log_probs),
                                         state.cur_len)
            log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
            log_probs = log_probs + jnp.expand_dims(state.running_scores,
                                                    axis=2)
            vocab_size = log_probs.shape[2]
            log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))

            # 3. Retrieve top-K
            # Each item in batch has num_beams * vocab_size candidate sequences.
            # For each item, get the top 2*k candidates with the highest log-
            # probabilities. We gather the top 2*K beams here so that even if the best
            # K sequences reach EOS simultaneously, we have another K sequences
            # remaining to continue the live beam search.
            # Gather the top 2*K scores from _all_ beams.
            # Gather 2*k top beams.
            # Recover the beam index by floor division.
            # Recover token id by modulo division and expand Id array for broadcasting.
            # Update sequences for the 2*K top-k new sequences.
            beams_to_keep = 2 * num_beams
            topk_log_probs, topk_indices = lax.top_k(log_probs,
                                                     k=beams_to_keep)
            topk_beam_indices = topk_indices // vocab_size
            topk_running_sequences = gather_beams(state.running_sequences,
                                                  topk_beam_indices,
                                                  batch_size, beams_to_keep)
            topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
            topk_sequences = lax.dynamic_update_slice(topk_running_sequences,
                                                      topk_ids,
                                                      (0, 0, state.cur_len))

            # 4. Check which sequences have ended
            # Update current sequences:
            # Did any of these sequences reach an end marker?
            # To prevent these just finished sequences from being added to the current sequences
            # set of active beam search sequences, set their log probs to a very large
            # negative value.
            did_topk_just_finished = topk_sequences[:, :, state.
                                                    cur_len] == eos_token_id
            running_topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(
                -1.0e7)
            # 5. Get running sequences scores for next
            # Determine the top k beam indices (from top 2*k beams) from log probs
            # and gather top k beams (from top 2*k beams).
            next_topk_indices = jnp.flip(lax.top_k(running_topk_log_probs,
                                                   k=num_beams)[1],
                                         axis=1)
            next_running_sequences, next_running_scores = gather_beams(
                [topk_sequences, running_topk_log_probs], next_topk_indices,
                batch_size, num_beams)

            # 6. Process topk logits
            # Further process log probs:
            # - add length penalty
            # - make sure no scores can be added anymore if beam is full
            # - make sure still running sequences cannot be chosen as finalized beam
            topk_log_probs = topk_log_probs / (state.cur_len**length_penalty)
            beams_in_batch_are_full = (jnp.broadcast_to(
                state.is_sent_finished.all(axis=-1, keepdims=True),
                did_topk_just_finished.shape)
                                       & early_stopping)
            add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
            topk_log_probs += add_penalty * np.array(-1.0e7)

            # 7. Get scores, sequences, is sentence finished for next.
            # Combine sequences, scores, and flags along the beam dimension and compare
            # new finished sequence scores to existing finished scores and select the
            # best from the new set of beams
            merged_sequences = jnp.concatenate(
                [state.sequences, topk_sequences], axis=1)
            merged_scores = jnp.concatenate([state.scores, topk_log_probs],
                                            axis=1)
            merged_is_sent_finished = jnp.concatenate(
                [state.is_sent_finished, did_topk_just_finished], axis=1)
            topk_merged_indices = jnp.flip(lax.top_k(merged_scores,
                                                     k=num_beams)[1],
                                           axis=1)
            next_sequences, next_scores, next_is_sent_finished = gather_beams(
                [merged_sequences, merged_scores, merged_is_sent_finished],
                topk_merged_indices, batch_size, num_beams)

            # 8. Update model kwargs.
            # Determine the top k beam indices from the original set of all beams.
            # With these, gather the top k beam-associated caches.
            next_running_indices = gather_beams(topk_beam_indices,
                                                next_topk_indices, batch_size,
                                                num_beams)
            next_cache = gather_beams(cache, next_running_indices, batch_size,
                                      num_beams)
            model_outputs["past_key_values"] = jax.tree_map(
                lambda x: flatten_beam_dim(x), next_cache)
            next_model_kwargs = self.update_inputs_for_generation(
                model_outputs, state.model_kwargs)

            return BeamSearchState(
                cur_len=state.cur_len + 1,
                running_scores=next_running_scores,
                running_sequences=next_running_sequences,
                scores=next_scores,
                sequences=next_sequences,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
            )

        # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
        if input_ids.shape[-1] > 1:
            state = partial(beam_search_body_fn,
                            input_ids_length=input_ids.shape[-1])(state)

        if not trace:
            state = self._run_loop_in_debug(beam_search_cond_fn,
                                            beam_search_body_fn, state)
        else:
            state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn,
                                   state)

        # Account for the edge-case where there are no finished sequences for a
        # particular batch item. If so, return running sequences for that batch item.
        none_finished = jnp.any(state.is_sent_finished, axis=1)
        sequences = jnp.where(none_finished[:, None, None], state.sequences,
                              state.running_sequences)
        scores = jnp.where(none_finished[:, None], state.scores,
                           state.running_scores)

        # take best beam for each batch
        sequences = sequences[:, -1]
        scores = scores[:, -1]

        return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
Esempio n. 27
0
  def apply(self,
            inputs_q,
            inputs_kv,
            num_heads,
            dtype=jnp.float32,
            qkv_features=None,
            out_features=None,
            attention_axis=None,
            causal_mask=False,
            padding_mask=None,
            key_padding_mask=None,
            segmentation=None,
            key_segmentation=None,
            cache=None,
            broadcast_dropout=True,
            dropout_rng=None,
            dropout_rate=0.,
            deterministic=False,
            precision=None,
            kernel_init=nn.linear.default_kernel_init,
            bias_init=nn.initializers.zeros,
            bias=True,
            block_size=50,
            max_num_blocks=25,
            sort_activation='softmax'):
    """Applies multi-head sinkhorn attention on the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention and project the results to an output vector.

    This can be used for encoder-decoder attention by specifying both `inputs_q`
    and `inputs_kv` orfor self-attention by only specifying `inputs_q` and
    setting `inputs_kv` to None.

    Args:
      inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
      inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]`
        or None for self-attention, inn which case key/values will be derived
        from inputs_q.
      num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
        should be divisible by the number of heads.
      dtype: the dtype of the computation (default: float32)
      qkv_features: dimension of the key, query, and value.
      out_features: dimension of the last projection
      attention_axis: axes over which the attention is applied ( 'None' means
        attention over all axes, but batch, heads, and features).
      causal_mask: boolean specifying whether to apply a causal mask on the
        attention weights. If True, the output at timestep `t` will not depend
        on inputs at timesteps strictly greater than `t`.
      padding_mask: boolean specifying query tokens that are pad token.
      key_padding_mask: boolean specifying key-value tokens that are pad token.
      segmentation: segment indices for packed inputs_q data.
      key_segmentation: segment indices for packed inputs_kv data.
      cache: an instance of `flax.nn.attention.Cache` used for efficient
        autoregressive decoding.
      broadcast_dropout: bool: use a broadcasted dropout along batch dims.
      dropout_rng: JAX PRNGKey: to be used for dropout
      dropout_rate: dropout rate
      deterministic: bool, deterministic or not (to apply dropout)
      precision: numerical precision of the computation see `jax.lax.Precision`
        for details.
      kernel_init: initializer for the kernel of the Dense layers.
      bias_init: initializer for the bias of the Dense layers.
      bias: bool: whether pointwise QKVO dense transforms use bias.
      block_size: int, block size.
      max_num_blocks:  int, max num blocks.
      sort_activation: str {softmax, sinkhorn, gumbel_sinkhorn}

    Returns:
      output of shape `[bs, dim1, dim2, ..., dimN, features]`.
    """

    assert causal_mask or not cache, (
        'Caching is only support for causal attention.')

    assert inputs_q.ndim == 3

    if inputs_kv is None:
      inputs_kv = inputs_q

    if attention_axis is None:
      attention_axis = tuple(range(1, inputs_q.ndim - 1))

    features = out_features or inputs_q.shape[-1]
    qkv_features = qkv_features or inputs_q.shape[-1]

    assert qkv_features % num_heads == 0, (
        'Memory dimension must be divisible by number of heads.')
    head_dim = qkv_features // num_heads

    dense = nn.DenseGeneral.partial(
        axis=-1,
        features=(num_heads, head_dim),
        kernel_init=kernel_init,
        bias_init=bias_init,
        bias=bias,
        precision=precision)
    # project inputs_q to multi-headed q/k/v
    # dimensions are then [bs, dims..., n_heads, n_features_per_head]
    qlength = inputs_q.shape[-2]
    bs = inputs_q.shape[0]
    kvlength = inputs_kv.shape[-2]

    query, key, value = (dense(inputs_q, dtype=dtype, name='query'),
                         dense(inputs_kv, dtype=dtype, name='key'),
                         dense(inputs_kv, dtype=dtype, name='value'))

    if cache:
      assert isinstance(cache, Cache), 'cache must be an instance of Cache'
      if self.is_initializing():
        cache.store(onp.array((key.ndim,) + key.shape[-2:], dtype=onp.int32))
      else:
        cache_entry = cache.retrieve(None)
        expected_shape = list(cache_entry.key.shape[:-2])
        for attn_dim in attention_axis:
          expected_shape[attn_dim] = 1
        expected_shape = tuple(expected_shape) + inputs_q.shape[-1:]
        if expected_shape != inputs_q.shape:
          raise ValueError('Invalid shape provided, '
                           'expected shape %s instead got %s.' %
                           (expected_shape, inputs_q.shape))

        if not isinstance(cache_entry, _CacheEntry):
          raise ValueError('Cache is not initialized.')

        cshape = cache_entry.key.shape
        indices = [0] * len(cshape)
        i = cache_entry.i
        attn_size = onp.prod(onp.take(cshape, attention_axis))
        for attn_dim in attention_axis:
          attn_size //= cshape[attn_dim]
          indices[attn_dim] = i // attn_size
          i = i % attn_size

        key = lax.dynamic_update_slice(cache_entry.key, key, indices)
        value = lax.dynamic_update_slice(cache_entry.value, value, indices)
        one = jnp.array(1, jnp.uint32)
        cache_entry = cache_entry.replace(i=cache_entry.i + one,
                                          key=key,
                                          value=value)
        cache.store(cache_entry)

        key_padding_mask = jnp.broadcast_to(
            (jnp.arange(cshape[1]) < cache_entry.i), cshape[:2])
        key_padding_mask = key_padding_mask.astype(jnp.float32)[..., None]

    # block reshape before attention
    num_query_blocks = qlength // block_size
    num_kv_blocks = kvlength // block_size

    block_query = jnp.reshape(
        query, (bs, block_size, num_query_blocks, num_heads, head_dim))
    block_key = jnp.reshape(
        key, (bs, block_size, num_kv_blocks, num_heads, head_dim))
    block_value = jnp.reshape(
        value, (bs, block_size, num_kv_blocks, num_heads, head_dim))

    if causal_mask:
      # causal masking needs to not have blocks with mixed information.
      sum_key = jnp.cumsum(block_key, axis=1)
      sum_key = sum_key[:, 0, :, :, :]  # take first item
    else:
      sum_key = jnp.sum(block_key, axis=1)

    # sort net on head_dim dimensions
    sort_out = nn.DenseGeneral(sum_key, axis=-1,
                               features=(max_num_blocks),
                               kernel_init=kernel_init,
                               bias_init=bias_init,
                               bias=bias,
                               precision=precision)

    # (bs x num_key_blocks x num_heads x num_key_blocks
    sort_out = sort_out[:, :, :, :num_query_blocks]

    # simple softmax sorting first.

    if sort_activation == 'sinkhorn':
      permutation = sinkhorn_operator(
          jnp.reshape(sort_out, (-1, num_kv_blocks, num_query_blocks)),
          causal=causal_mask)
      permutation = jnp.reshape(permutation, (-1, num_kv_blocks, num_heads,
                                              num_query_blocks))
    else:
      if causal_mask:
        block_mask = _make_causal_mask(key, attention_axis)
        sort_out += block_mask
      permutation = jax.nn.softmax(sort_out, axis=-1)

    sorted_key = jnp.einsum('bskhd,bnhl->bsnhd', block_key, permutation)
    sorted_value = jnp.einsum('bskhd,bnhl->bsnhd', block_value, permutation)

    # create attention masks
    mask_components = []
    sorted_mask_components = []

    if causal_mask:
      # TODO(yitay): Test this causal masking.
      if cache and not self.is_initializing():
        bias_pre_shape = (1,) * (key.ndim - 1)
        attn_shape = tuple(onp.take(key.shape, attention_axis))
        attn_size = onp.prod(attn_shape)
        ii = jnp.arange(attn_size, dtype=jnp.uint32)
        mask = ii < cache_entry.i
        mask_components.append(mask.reshape(bias_pre_shape + attn_shape))
      else:
        mask_components.append(_make_causal_mask(key, attention_axis))

    if padding_mask is not None:
      # divide padding mask into block
      padding_mask = jnp.reshape(padding_mask,
                                 (bs * num_query_blocks, block_size, 1))
      if key_padding_mask is None:
        key_padding_mask = padding_mask

      padding_mask = make_padding_mask(
          padding_mask_query=padding_mask,
          padding_mask_key=key_padding_mask,
          query_shape=(bs * num_query_blocks, block_size, num_heads, head_dim),
          key_shape=(bs * num_kv_blocks, block_size, num_heads, head_dim),
          attention_axis=attention_axis)

      padding_mask = jnp.reshape(padding_mask,
                                 (bs, num_query_blocks, block_size, block_size))
      mask_components.append(padding_mask)
      sorted_padding_mask = jnp.einsum('bksj,bnhl->bnsj', padding_mask,
                                       permutation)
      sorted_mask_components.append(sorted_padding_mask)

    if segmentation is not None:
      if key_segmentation is None:
        key_segmentation = segmentation
      segmentation_mask = make_padding_mask(
          padding_mask_query=segmentation,
          padding_mask_key=key_segmentation,
          query_shape=(bs * num_query_blocks, block_size, num_heads, head_dim),
          key_shape=(bs * num_kv_blocks, block_size, num_heads, head_dim),
          attention_axis=attention_axis,
          segmentation_mask=True)
      segmentation_mask = jnp.reshape(segmentation_mask,
                                      (bs, num_query_blocks, block_size,
                                       block_size))
      mask_components.append(segmentation_mask)
      sorted_segmentation_mask = jnp.einsum('bksj,bnhl->bnsj',
                                            segmentation_mask,
                                            permutation)
      sorted_mask_components.append(sorted_segmentation_mask)

    if mask_components:
      attention_mask = mask_components[0]
      for component in mask_components[1:]:
        attention_mask = jnp.logical_and(attention_mask, component)

      # attention mask in the form of attention bias
      attention_bias = lax.select(
          attention_mask > 0, jnp.full(attention_mask.shape, 0.).astype(dtype),
          jnp.full(attention_mask.shape, -1e10).astype(dtype))
    else:
      attention_bias = None

    if sorted_mask_components:
      attention_mask = sorted_mask_components[0]
      for component in sorted_mask_components[1:]:
        attention_mask = jnp.logical_and(attention_mask, component)

      # attention mask in the form of attention bias
      sorted_attention_bias = lax.select(
          attention_mask > 0, jnp.full(attention_mask.shape, 0.).astype(dtype),
          jnp.full(attention_mask.shape, -1e10).astype(dtype))
    else:
      sorted_attention_bias = None

    # apply attention
    x = local_dot_product_attention(
        block_query,
        block_key,
        block_value,
        dtype=dtype,
        axis=attention_axis,
        bias=attention_bias,
        precision=precision,
        dropout_rng=dropout_rng,
        dropout_rate=dropout_rate,
        broadcast_dropout=broadcast_dropout,
        deterministic=deterministic)

    sorted_x = local_dot_product_attention(
        block_query,
        sorted_key,
        sorted_value,
        dtype=dtype,
        axis=attention_axis,
        bias=sorted_attention_bias,
        precision=precision,
        dropout_rng=dropout_rng,
        dropout_rate=dropout_rate,
        broadcast_dropout=broadcast_dropout,
        deterministic=deterministic)

    x = x + sorted_x

    x = jnp.reshape(x, (bs, qlength, num_heads, head_dim))

    # back to the original inputs dimensions
    out = nn.DenseGeneral(
        x,
        features=features,
        axis=(-2, -1),
        kernel_init=kernel_init,
        bias_init=bias_init,
        bias=bias,
        dtype=dtype,
        precision=precision,
        name='out')

    return out
Esempio n. 28
0
    _make_harness("dot_general", "",
                  lambda lhs, rhs: lax.dot_general(lhs, rhs,
                                                   dimension_numbers=(((2,), (1,)), ((0,), (0,)))),
                  [RandArg((3, 4, 4), _f32), RandArg((3, 4), _f32)],
                  poly_axes=[0, 0]),

    _make_harness("dynamic_slice", "",
                  # x:shape: (b, 4)
                  lambda x: lax.dynamic_slice(x, (0, 1), (x.shape[0], 2)),
                  [RandArg((3, 4), _f32)],
                  poly_axes=[0]),

    _make_harness("dynamic_update_slice", "",
                  # x:shape: (b, 4)
                  lambda x: lax.dynamic_update_slice(x, x, (0, 0)),
                  [RandArg((3, 4), _f32)],
                  poly_axes=[0]),

    _make_harness("einsum", "0",
                  lambda x: jnp.einsum("...i->...", x),
                  [RandArg((3, 4), _f32)],
                  poly_axes=[0]),

    _make_harness("einsum", "1",
                  lambda x, y: jnp.einsum("...ij,...jk->...ik", x, y),
                  [RandArg((3, 4, 5), _f32), RandArg((3, 5, 6), _f32)],
                  poly_axes=[0, 0]),

    _make_harness("einsum", "2",
                  lambda x, y: jnp.einsum("...ij,jk->...ik", x, y),
Esempio n. 29
0
    def __call__(self,
                 inputs_q: Array,
                 inputs_kv: Array,
                 mask: Optional[Array] = None):
        """Applies multi-head dot product attention on the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention and project the results to an output vector.

    Args:
      inputs_q: input queries of shape
        `[batch_sizes..., length, features]`.
      inputs_kv: key/values of shape
        `[batch_sizes..., length, features]`.
      mask: attention mask of shape
        `[batch_sizes..., num_heads, query_length, key/value_length]`.

    Returns:
      output of shape `[batch_sizes..., length, features]`.
    """
        features = self.out_features or inputs_q.shape[-1]
        qkv_features = self.qkv_features or inputs_q.shape[-1]
        assert qkv_features % self.num_heads == 0, (
            'Memory dimension must be divisible by number of heads.')
        head_dim = qkv_features // self.num_heads

        dense = partial(DenseGeneral,
                        axis=-1,
                        features=(self.num_heads, head_dim),
                        kernel_init=self.kernel_init,
                        bias_init=self.bias_init,
                        use_bias=self.use_bias,
                        precision=self.precision)
        # project inputs_q to multi-headed q/k/v
        # dimensions are then [batch..., length, n_heads, n_features_per_head]
        query, key, value = (dense(dtype=self.dtype, name='query')(inputs_q),
                             dense(dtype=self.dtype, name='key')(inputs_kv),
                             dense(dtype=self.dtype, name='value')(inputs_kv))

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.decode:
            # detect if we're initializing by absence of existing cache data.
            is_initialized = self.has_variable('cache', 'cached_key')
            cached_key = self.variable('cache', 'cached_key', jnp.zeros,
                                       key.shape, key.dtype)
            cached_value = self.variable('cache', 'cached_value', jnp.zeros,
                                         value.shape, value.dtype)
            cache_index = self.variable('cache', 'cache_index',
                                        lambda: jnp.array(0, dtype=jnp.int32))
            if is_initialized:
                *batch_dims, max_length, num_heads, depth_per_head = (
                    cached_key.value.shape)
                # shape check of cached keys against query input
                expected_shape = tuple(batch_dims) + (1, num_heads,
                                                      depth_per_head)
                if expected_shape != query.shape:
                    raise ValueError(
                        'Autoregressive cache shape error, '
                        'expected query shape %s instead got %s.' %
                        (expected_shape, query.shape))
                # update key, value caches with our new 1d spatial slices
                cur_index = cache_index.value
                indices = (0, ) * len(batch_dims) + (cur_index, 0, 0)
                key = lax.dynamic_update_slice(cached_key.value, key, indices)
                value = lax.dynamic_update_slice(cached_value.value, value,
                                                 indices)
                cached_key.value = key
                cached_value.value = value
                cache_index.value = cache_index.value + 1
                # causal mask for cached decoder self-attention:
                # our single query position should only attend to those key
                # positions that have already been generated and cached,
                # not the remaining zero elements.
                mask = combine_masks(
                    mask,
                    jnp.broadcast_to(
                        jnp.arange(max_length) <= cur_index,
                        tuple(batch_dims) + (1, 1, max_length)))

        # Convert the boolean attention mask to an attention bias.
        if mask is not None:
            # attention mask in the form of attention bias
            attention_bias = lax.select(
                mask > 0,
                jnp.full(mask.shape, 0.).astype(self.dtype),
                jnp.full(mask.shape, -1e10).astype(self.dtype))
        else:
            attention_bias = None

        dropout_rng = None
        if not self.deterministic and self.dropout_rate > 0.:
            dropout_rng = self.make_rng('dropout')

        # apply attention
        x = self.attention_fn(query,
                              key,
                              value,
                              bias=attention_bias,
                              dropout_rng=dropout_rng,
                              dropout_rate=self.dropout_rate,
                              broadcast_dropout=self.broadcast_dropout,
                              deterministic=self.deterministic,
                              dtype=self.dtype,
                              precision=self.precision)

        # back to the original inputs dimensions
        out = DenseGeneral(features=features,
                           axis=(-2, -1),
                           kernel_init=self.kernel_init,
                           bias_init=self.bias_init,
                           use_bias=self.use_bias,
                           dtype=self.dtype,
                           precision=self.precision,
                           name='out')(x)
        return out
Esempio n. 30
0
 def _add(cache):
     # return cache.at[:, -1, index_w_in, :].set(inputs)
     return lax.dynamic_update_slice(cache, inputs,
                                     (0, -1, index_w_in, 0))