def partition(arr, low, high):
    """Lomuto partition function."""
    if len(arr.shape) > 1:
        raise ValueError("Partition works on 1D arrays. Use vmap.")
    pivot = arr[high]

    def body(state):
        (j, i, arr) = state
        do_swap = arr[j] <= pivot
        i = jnp.where(do_swap, i + 1, i)
        ai = arr[i, None]
        aj = arr[j, None]
        arr_swapped = dynamic_update_slice(arr, aj, [i])
        arr_swapped = dynamic_update_slice(arr_swapped, ai, [j])
        arr_swapped = jnp.where(do_swap, arr_swapped, arr)
        return (j + 1, i, arr_swapped)

    (j, i, arr) = while_loop(lambda state: state[0] < high, body,
                             (low, low - 1, arr))
    ai = arr[i + 1, None]
    aj = arr[high, None]
    arr_swapped = dynamic_update_slice(arr, aj, [i + 1])
    arr_swapped = dynamic_update_slice(arr_swapped, ai, [high])
    return (i + 1, arr_swapped)
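# Usage sketch for the partition above, assuming `partition` and its imports
# are in scope: batch it over rows with `vmap`, as the ValueError message
# suggests. The example data is illustrative, not from the source.
import jax.numpy as jnp
from jax import vmap
from jax.lax import dynamic_update_slice, while_loop

rows = jnp.array([[3., 1., 2.], [9., 7., 8.]])
pivot_pos, partitioned = vmap(lambda a: partition(a, 0, a.shape[0] - 1))(rows)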
def _concatenate_to_cache(self, key, value, query, attention_mask):
    """
    This function takes projected key, value states from a single input token
    and concatenates the states to cached states from previous steps. This
    function is slightly adapted from the official Flax repository:
    https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
    """
    # detect if we're initializing by absence of existing cache data.
    is_initialized = self.has_variable("cache", "cached_key")
    cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
    cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
    cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
    if is_initialized:
        *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
        # update key, value caches with our new 1d spatial slices
        cur_index = cache_index.value
        indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
        key = lax.dynamic_update_slice(cached_key.value, key, indices)
        value = lax.dynamic_update_slice(cached_value.value, value, indices)
        cached_key.value = key
        cached_value.value = value
        num_updated_cache_vectors = query.shape[1]
        cache_index.value = cache_index.value + num_updated_cache_vectors
        # causal mask for cached decoder self-attention: our single query
        # position should only attend to those key positions that have already
        # been generated and cached, not the remaining zero elements.
        pad_mask = jnp.broadcast_to(
            jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
            tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
        )
        attention_mask = combine_masks(pad_mask, attention_mask)
    return key, value, attention_mask
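# The cache write at the heart of the method above, outside of Flax. The
# shapes (batch, max_length, num_heads, head_dim) are illustrative
# assumptions, not values from the source.
import jax.numpy as jnp
from jax import lax

batch, max_length, num_heads, head_dim = 1, 8, 2, 4
cached_key = jnp.zeros((batch, max_length, num_heads, head_dim))
new_key = jnp.ones((batch, 1, num_heads, head_dim))  # one decoding step
cur_index = 3  # would be cache_index.value in the module
cached_key = lax.dynamic_update_slice(cached_key, new_key, (0, cur_index, 0, 0))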
def _replace_result(operand, update1, update2):
    operand = dynamic_update_slice(
        operand, update1,
        jnp.asarray([i_lowest] + [0] * (len(operand.shape) - 1)))
    operand = dynamic_update_slice(
        operand, update2,
        jnp.asarray([splitting + 1] + [0] * (len(operand.shape) - 1)))
    return operand
def testDynamicUpdateSliceGrad(self, shape, dtype, start_indices, update_shape):
    rng = jtu.rand_default(self.rng())
    operand = rng(shape, dtype)
    update = rng(update_shape, dtype)
    start_indices = np.array(start_indices)
    dus = lambda x, y: lax.dynamic_update_slice(x, y, start_indices)
    check_grads(dus, (operand, update), 2, ["fwd", "rev"], eps=1.)
    dus = lambda x: lax.dynamic_update_slice(x, update, start_indices)
    check_grads(dus, (operand,), 2, ["fwd", "rev"], eps=1.)
    dus = lambda y: lax.dynamic_update_slice(operand, y, start_indices)
    check_grads(dus, (update,), 2, ["fwd", "rev"], eps=1.)
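# Standalone version of the same gradient check using only public JAX APIs.
# The concrete shapes and indices below are arbitrary choices, not from the
# test suite.
import jax.numpy as jnp
from jax import lax
from jax.test_util import check_grads

operand = jnp.arange(12., dtype=jnp.float32).reshape(3, 4)
update = jnp.ones((2, 2), dtype=jnp.float32)
dus = lambda x, y: lax.dynamic_update_slice(x, y, (1, 1))
check_grads(dus, (operand, update), order=2, modes=["fwd", "rev"], eps=1.)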
def sampling_loop_body_fn(state):
    """Sampling loop state update."""
    i, sequences, cache, cur_token, ended, rng = state
    # Split RNG for sampling.
    rng1, rng2 = random.split(rng)
    # Call fast-decoder model on current tokens to get next-position logits.
    logits, new_cache = tokens_to_logits(cur_token, cache)
    # Sample next token from logits.
    # TODO(levskaya): add top-p "nucleus" sampling option.
    if topk:
        # Get top-k logits and their indices, sample within these top-k tokens.
        topk_logits, topk_idxs = lax.top_k(logits, topk)
        topk_token = jnp.expand_dims(
            random.categorical(rng1, topk_logits / temperature).astype(jnp.int32),
            axis=-1)
        # Return the original indices corresponding to the sampled top-k tokens.
        next_token = jnp.squeeze(
            jnp.take_along_axis(topk_idxs, topk_token, axis=-1), axis=-1)
    else:
        next_token = random.categorical(rng1, logits / temperature).astype(jnp.int32)
    # Only use sampled tokens if we're past provided prefix tokens.
    out_of_prompt = (sequences[:, i + 1] == 0)
    next_token = (next_token * out_of_prompt +
                  sequences[:, i + 1] * ~out_of_prompt)
    # If end-marker reached for batch item, only emit padding tokens.
    next_token_or_endpad = next_token * ~ended
    ended |= (next_token_or_endpad == end_marker)
    # Add current sampled tokens to recorded sequences.
    new_sequences = lax.dynamic_update_slice(sequences, next_token_or_endpad,
                                             (0, i + 1))
    return (i + 1, new_sequences, new_cache, next_token_or_endpad, ended, rng2)
def greedy_search_body_fn(state):
    """state update fn."""
    model_outputs = model(state.running_token, params=params, **state.model_kwargs)
    logits = model_outputs.logits[:, -1]

    # apply min_length, ...
    logits = logits_processor(state.sequences, logits, state.cur_len)

    next_token = jnp.argmax(logits, axis=-1)
    next_token = next_token * ~state.is_sent_finished + pad_token_id * state.is_sent_finished
    next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
    next_token = next_token[:, None]

    next_sequences = lax.dynamic_update_slice(state.sequences, next_token,
                                              (0, state.cur_len))
    next_model_kwargs = self.update_inputs_for_generation(model_outputs,
                                                          state.model_kwargs)
    return GreedyState(
        cur_len=state.cur_len + 1,
        sequences=next_sequences,
        running_token=next_token,
        is_sent_finished=next_is_sent_finished,
        model_kwargs=next_model_kwargs,
    )
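# The append-one-column-per-step pattern shared by the generation loops in
# this collection, reduced to a runnable toy (the token values are dummies,
# not model output):
import jax.numpy as jnp
from jax import lax

batch, max_len = 2, 5
init_sequences = jnp.zeros((batch, max_len), dtype=jnp.int32)

def body(state):
    i, seqs = state
    next_token = jnp.full((batch, 1), i + 1, dtype=jnp.int32)  # stand-in token
    return i + 1, lax.dynamic_update_slice(seqs, next_token, (0, i))

_, filled = lax.while_loop(lambda s: s[0] < max_len, body, (0, init_sequences))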
def body(state: State):
    # N, K
    cluster_dist = vmap(lambda point: cluster_dist_metric(point, state.metric_state))(points)
    # N
    current_cluster_dist = vmap(lambda n, k: cluster_dist[n, k])(
        jnp.arange(N, dtype=jnp.int_), state.cluster_id)
    # N, K
    rel_dist = cluster_dist - current_cluster_dist[:, None]
    # N
    min_dist = jnp.min(rel_dist, axis=-1)
    proposed_cluster_id = jnp.argmin(rel_dist, axis=-1)
    can_take_from = state.metric_state.num_k[state.cluster_id] > D + 1
    min_dist = jnp.where(mask & can_take_from, min_dist, jnp.inf)
    amin = jnp.argmin(min_dist)
    k_to = proposed_cluster_id[amin]
    cluster_id = dynamic_update_slice(state.cluster_id, k_to[None], amin[None])
    # # update cluster_id
    # cluster_id = jnp.where(state.metric_state.num_k[state.cluster_id] < D + 1,
    #                        state.cluster_id, jnp.argmin(rel_dist, axis=-1))
    # proposed_num_k = jnp.bincount(proposed_cluster_id, weights, minlength=0, length=K)
    # cluster_id = jnp.where(proposed_num_k[proposed_cluster_id] < D + 1,
    #                        state.cluster_id, proposed_cluster_id)
    metric_state = update_metric_state(cluster_id)
    # print()
    # print(state.i, jnp.sum(state.cluster_id != cluster_id), amin,
    #       state.cluster_id[amin], k_to, jnp.min(rel_dist))
    done = jnp.all(cluster_id == state.cluster_id)
    state = state._replace(i=state.i + 1, done=done, cluster_id=cluster_id,
                           metric_state=metric_state)
    return state
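# The single-element reassignment used above, in isolation: k_to[None] turns a
# scalar into a length-1 update and amin[None] into a length-1 start index,
# i.e. a jit-friendly cluster_id[amin] = k_to. The data is illustrative.
import jax.numpy as jnp
from jax.lax import dynamic_update_slice

cluster_id = jnp.array([0, 0, 1, 1], dtype=jnp.int32)
amin = jnp.argmin(jnp.array([3., 1., 2., 5.]))
k_to = jnp.int32(2)
cluster_id = dynamic_update_slice(cluster_id, k_to[None], amin[None])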
def prepare_inputs_for_generation(self, input_ids, max_length,
                                  attention_mask: Optional[jnp.DeviceArray] = None):
    # initializing the cache
    batch_size, seq_length = input_ids.shape

    past_key_values = self.init_cache(batch_size, max_length)
    # Note that usually one would have to put 0's in the attention_mask for
    # x > input_ids.shape[-1] and x < cache_length. But since GPTJ uses a
    # causal mask, those positions are masked anyway. Thus we can create a
    # single static attention_mask here, which is more efficient for compilation.
    extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
    if attention_mask is not None:
        position_ids = attention_mask.cumsum(axis=-1) - 1
        extended_attention_mask = lax.dynamic_update_slice(
            extended_attention_mask, attention_mask, (0, 0))
    else:
        position_ids = jnp.broadcast_to(
            jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

    return {
        "past_key_values": past_key_values,
        "attention_mask": extended_attention_mask,
        "position_ids": position_ids,
    }
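# The static-mask trick above, runnable on its own (the sizes and mask values
# are made up for illustration):
import jax.numpy as jnp
from jax import lax

batch_size, max_length = 2, 8
attention_mask = jnp.array([[1, 1, 1], [1, 1, 0]], dtype="i4")
extended = jnp.ones((batch_size, max_length), dtype="i4")
extended = lax.dynamic_update_slice(extended, attention_mask, (0, 0))
position_ids = attention_mask.cumsum(axis=-1) - 1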
def loop_trans(j, coords):
    i = (n - j) - 1
    transformed_coords = extend(Triplet(*[di[i] for di in tris]), coords, False)
    return dynamic_update_slice(transformed_coords, coords_pretrans[i],
                                [fs * i] + [0] * (transformed_coords.ndim - 1))
def body(state):
    (i, done, old_cluster_id, _, _, _, _, _, _, _, _, min_loss, delay) = state
    mask1 = mask & (old_cluster_id == 0)
    mask2 = mask & (old_cluster_id == 1)
    # estimate volumes of current clustering
    n1 = jnp.sum(mask1)
    n2 = jnp.sum(mask2)
    log_VS1 = log_VS + jnp.log(n1) - jnp.log(n_S)
    log_VS2 = log_VS + jnp.log(n2) - jnp.log(n_S)
    # construct E_1, E_2 and compute volumes
    mu1, C1 = bounding_ellipsoid(points, mask1)
    radii1, rotation1 = ellipsoid_params(C1)
    log_VE1 = log_ellipsoid_volume(radii1)
    mu2, C2 = bounding_ellipsoid(points, mask2)
    radii2, rotation2 = ellipsoid_params(C2)
    log_VE2 = log_ellipsoid_volume(radii2)
    # enlarge to at least cover V(S1) and V(S2)
    log_scale1 = log_coverage_scale(log_VE1, log_VS1, D)
    log_scale2 = log_coverage_scale(log_VE2, log_VS2, D)
    C1 = C1 / jnp.exp(log_scale1)
    radii1 = jnp.exp(jnp.log(radii1) + log_scale1)
    C2 = C2 / jnp.exp(log_scale2)
    radii2 = jnp.exp(jnp.log(radii2) + log_scale2)
    log_VE1 = log_VE1 + log_scale1 * D
    log_VE2 = log_VE2 + log_scale2 * D
    # compute reassignment metrics
    maha1 = vmap(lambda point: (point - mu1) @ C1 @ (point - mu1))(points)
    maha2 = vmap(lambda point: (point - mu2) @ C2 @ (point - mu2))(points)
    log_h1 = log_VE1 - log_VS1 + jnp.log(maha1)
    log_h2 = log_VE2 - log_VS2 + jnp.log(maha2)
    # reassign
    delta_F = jnp.exp(log_h1) - jnp.exp(log_h2)
    reassign_idx = jnp.argmax(jnp.abs(delta_F))
    new_cluster_id = dynamic_update_slice(
        cluster_id, (delta_F[reassign_idx, None] > 0).astype(jnp.int_),
        reassign_idx[None])
    # new_cluster_id = jnp.where(log_h1 < log_h2, 0, 1)
    log_V_sum = jnp.logaddexp(log_VE1, log_VE2)
    new_loss = jnp.exp(log_V_sum - log_VS)
    loss_decreased = new_loss < min_loss
    delay = jnp.where(loss_decreased, 0, delay + 1)
    min_loss = jnp.where(loss_decreased, new_loss, min_loss)
    ###
    # i / delay / loss_decreased / new_loss / min_loss
    # 0 / 0 / True  / a / a
    # 1 / 1 / False / b / a
    # 2 / 2 / False / a / a
    # 3 / 3 / False / b / a
    # 4 / 4 / False / a / a
    done = jnp.all(new_cluster_id == old_cluster_id) \
           | (delay >= 10) \
           | (n1 < D + 1) \
           | (n2 < D + 1) \
           | jnp.isnan(log_V_sum)
    # print(i, "reassignments", jnp.sum(new_cluster_id != old_cluster_id), 'F', log_V_sum)
    # print(i, done, jnp.abs(delta_F).max())
    return (i + 1, done, new_cluster_id, log_VS1, mu1, radii1, rotation1,
            log_VS2, mu2, radii2, rotation2, min_loss, delay)
def generate(prompt_ids):
    def first_pass(prompt_ids):
        logits, cache = model(prompt_ids, past_key_values=past_key_values)[:2]
        next_token = jnp.argmax(logits[:, -1:], axis=-1)
        return next_token, cache

    def greedy_search_cond_fn(state):
        cur_len, _, _, _ = state
        return ~(cur_len == max_length - 1)

    def greedy_search_body_fn(state):
        cur_len, sequences, current_token, cache = state
        next_sequences = lax.dynamic_update_slice(sequences, current_token,
                                                  (0, cur_len))
        next_logits, next_cache = model(current_token, past_key_values=cache)[:2]
        next_token = jnp.argmax(next_logits, axis=-1)
        return cur_len + 1, next_sequences, next_token, next_cache

    # init tensor to be filled with generation result
    init_sequences = jnp.zeros((batch_size, max_length), dtype="i4")
    init_sequences = lax.dynamic_update_slice(init_sequences, prompt_ids, (0, 0))

    # init past key values for cache
    past_key_values = model.init_cache(batch_size, max_length)

    # first pass with long prompt
    next_token, cache = first_pass(prompt_ids)

    # prepare state for generation loop
    init_state = (jnp.array(prompt_length), init_sequences, next_token, cache)

    # fast generation
    _, output_sequences, final_token, _ = lax.while_loop(
        greedy_search_cond_fn, greedy_search_body_fn, init_state)

    # append last token
    output_sequences = lax.dynamic_update_slice(
        output_sequences, final_token, (0, max_length - 1))

    return output_sequences
def sampling_loop_body_fn(state):
    """Sampling loop state update."""
    i, sequences, cache, cur_token, ended, rng, tokens_to_logits_state = state

    # Split RNG for sampling.
    rng1, rng2 = random.split(rng)

    # Call fast-decoder model on current tokens to get raw next-position logits.
    logits, new_cache, new_tokens_to_logits_state = tokens_to_logits(
        cur_token, cache, internal_state=tokens_to_logits_state)
    logits = logits / temperature

    # Mask out any tokens that should never be sampled (e.g. the BOS token).
    if masked_tokens is not None:
        mask = common_utils.onehot(jnp.array(masked_tokens),
                                   num_classes=logits.shape[-1],
                                   on_value=LARGE_NEGATIVE)
        mask = jnp.sum(mask, axis=0)[None, :]  # Combine multiple masks together
        logits = logits + mask

    # Apply the repetition penalty.
    if repetition_penalty != 1:
        logits = apply_repetition_penalty(
            sequences, logits, i,
            repetition_penalty=repetition_penalty,
            repetition_window=repetition_window,
            repetition_penalty_normalize=repetition_penalty_normalize)

    # Mask out everything but the top-k entries.
    if top_k is not None:
        # Compute top_k_index and top_k_threshold with shapes (batch_size, 1).
        top_k_index = jnp.argsort(logits, axis=-1)[:, ::-1][:, top_k - 1:top_k]
        top_k_threshold = jnp.take_along_axis(logits, top_k_index, axis=-1)
        logits = jnp.where(logits < top_k_threshold,
                           jnp.full_like(logits, LARGE_NEGATIVE),
                           logits)

    # Sample next token from logits.
    sample = multinomial(rng1, logits)
    next_token = sample.astype(jnp.int32)

    # Only use sampled tokens if we are past the out_of_prompt_marker.
    out_of_prompt = (sequences[:, i + 1] == out_of_prompt_marker)
    next_token = (next_token * out_of_prompt +
                  sequences[:, i + 1] * ~out_of_prompt)

    # If end-marker reached for batch item, only emit padding tokens.
    next_token = next_token[:, None]
    next_token_or_endpad = jnp.where(ended,
                                     jnp.full_like(next_token, pad_token),
                                     next_token)
    ended |= (next_token_or_endpad == end_marker)

    # Add current sampled tokens to recorded sequences.
    new_sequences = lax.dynamic_update_slice(sequences, next_token_or_endpad,
                                             (0, i + 1))
    return (i + 1, new_sequences, new_cache, next_token_or_endpad, ended, rng2,
            new_tokens_to_logits_state)
def _matrix_put(ndarray, idx, val, block_size=1):
    """Similar to numpy.put using LAX operations."""
    idx_i, idx_j = idx
    sli, row_rev = _canonical_idx(ndarray.shape, idx_i, -2, block_size)
    slj, col_rev = _canonical_idx(ndarray.shape, idx_j, -1, block_size)
    if not sli.step == slj.step == 1:
        raise TypeError("Non-unit step not supported in assignment.")

    if row_rev or col_rev:
        val = lax.rev(val, *onp.where([row_rev, col_rev]))

    start_indices = [0] * (ndarray.ndim - 2) + [sli.start, slj.start]
    return lax.dynamic_update_slice(ndarray, val, start_indices)
def _update_slice(operand, update, start_indices, update_dims):
    """
    Similar to lax.dynamic_update_slice, but handles padded updates where
    padding values should not overwrite existing values in the array.

    Args:
      operand: the array to update.
      update: the padded array to write.
      start_indices: the offset at which to write `update`.
      update_dims: the true dimensions of the padded update `update`. Only
        values inside the rectangle given by `update_dims` will be overwritten.
    """
    operand_shape = operand.shape
    operand = lax.pad(operand, jnp.array(0, operand.dtype),
                      [(0, d, 0) for d in update.shape])
    start_indices = tuple(jnp.int32(i) for i in start_indices)
    t = lax.dynamic_slice(operand, start_indices, update.shape)
    t = _mask(update, update_dims, t)
    operand = lax.dynamic_update_slice(operand, t, start_indices)
    return lax.slice(operand, [0] * operand.ndim, operand_shape)
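# Why the padding above is needed: lax.dynamic_update_slice clamps start
# indices so the update always fits inside the operand, which would silently
# shift a partially out-of-bounds write. Quick illustration:
import jax.numpy as jnp
from jax import lax

x = jnp.zeros(6)
y = jnp.ones(3)
print(lax.dynamic_update_slice(x, y, (4,)))  # start clamped from 4 to 3
# [0. 0. 0. 1. 1. 1.]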
def body(state):
    (i, u, p, num_f, distance, num_bounce, results_array) = state
    # half step forward
    u = u + step_size * p
    # check violation
    C, C_grad = val_and_grad_from_U(u)
    n = C_grad / jnp.linalg.norm(C_grad)
    # bounce off contours and domain boundaries
    p_bounce = jnp.where(((u > 1.) | (u < 0.)) | jnp.isnan(n),
                         -p,
                         p - 2. * (p @ n) * n)
    # update momentum
    p = jnp.where(C >= 0., p, p_bounce)
    # update counters
    num_bounce = jnp.where(C >= 0., num_bounce, num_bounce + 1)
    distance = distance + step_size
    num_f = num_f + 1
    results_array = dynamic_update_slice(results_array, u[None, :], [i, 0])
    return (i + 1, u, p, num_f, distance, num_bounce, results_array)
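# The trajectory-recording pattern above, reduced to a runnable toy: write row
# i of a preallocated (steps, dim) buffer on each iteration (the row values
# are stand-ins for the sampler's state).
import jax.numpy as jnp
from jax import lax

steps, dim = 4, 3

def record(i, results):
    row = jnp.sin(jnp.arange(dim, dtype=jnp.float32) + i)  # stand-in sample
    return lax.dynamic_update_slice(results, row[None, :], (i, 0))

trajectory = lax.fori_loop(0, steps, record, jnp.zeros((steps, dim)))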
def greedy_search_body_fn(state):
    """state update fn."""
    model_outputs = model(state.current_token, **state.model_kwargs)
    next_token = jnp.argmax(model_outputs.logits[:, -1], axis=-1)

    next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
    next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
    next_token = next_token[:, None]

    next_sequences = lax.dynamic_update_slice(state.sequences, next_token,
                                              (0, state.cur_len))
    # pass the loop-carried kwargs from `state` (the bare `model_kwargs`
    # free variable in the original was a bug)
    next_model_kwargs = model.update_inputs_for_generation(model_outputs,
                                                           state.model_kwargs)
    return GreedyState(
        cur_len=state.cur_len + 1,
        sequences=next_sequences,
        current_token=next_token,
        is_sent_finished=next_is_sent_finished,
        model_kwargs=next_model_kwargs,
    )
def _sample(
    self,
    input_ids: jnp.ndarray,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    prng_key: Optional[jnp.ndarray] = None,
    logits_processor: Optional[FlaxLogitsProcessorList] = None,
    logits_warper: Optional[FlaxLogitsProcessorList] = None,
    trace: bool = True,
    params: Optional[Dict[str, jnp.ndarray]] = None,
    model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
):
    # init values
    max_length = max_length if max_length is not None else self.config.max_length
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
    prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)

    batch_size, cur_len = input_ids.shape

    eos_token_id = jnp.array(eos_token_id)
    pad_token_id = jnp.array(pad_token_id)
    cur_len = jnp.array(cur_len)

    # per batch-item holding current token in loop.
    sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
    sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))

    # per batch-item state bit indicating if sentence has finished.
    is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)

    # For Seq2Seq generation, we only need to use the decoder instead of the
    # whole model in the generation loop and pass it the `encoder_outputs`,
    # which are part of the `model_kwargs`.
    model = self.decode if self.config.is_encoder_decoder else self

    # initialize model specific kwargs
    model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length,
                                                      **model_kwargs)

    # initialize state
    state = SampleState(
        cur_len=cur_len,
        sequences=sequences,
        running_token=input_ids,
        is_sent_finished=is_sent_finished,
        prng_key=prng_key,
        model_kwargs=model_kwargs,
    )

    def sample_search_cond_fn(state):
        """state termination condition fn."""
        has_reached_max_length = state.cur_len == max_length
        all_sequence_finished = jnp.all(state.is_sent_finished)
        finish_generation = jnp.logical_or(has_reached_max_length,
                                           all_sequence_finished)
        return ~finish_generation

    def sample_search_body_fn(state):
        """state update fn."""
        prng_key, prng_key_next = jax.random.split(state.prng_key)
        model_outputs = model(state.running_token, params=params, **state.model_kwargs)

        logits = model_outputs.logits[:, -1]

        # apply min_length, ...
        logits = logits_processor(state.sequences, logits, state.cur_len)
        # apply top_p, top_k, temperature
        logits = logits_warper(logits, logits, state.cur_len)

        next_token = jax.random.categorical(prng_key, logits, axis=-1)

        next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
        next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
        next_token = next_token[:, None]

        next_sequences = lax.dynamic_update_slice(state.sequences, next_token,
                                                  (0, state.cur_len))
        next_model_kwargs = self.update_inputs_for_generation(model_outputs,
                                                              state.model_kwargs)

        return SampleState(
            cur_len=state.cur_len + 1,
            sequences=next_sequences,
            running_token=next_token,
            is_sent_finished=next_is_sent_finished,
            model_kwargs=next_model_kwargs,
            prng_key=prng_key_next,
        )

    # The very first prompt often has sequence length > 1, so run outside of
    # `lax.while_loop` to comply with TPU
    if input_ids.shape[1] > 1:
        state = sample_search_body_fn(state)

    if not trace:
        state = self._run_loop_in_debug(sample_search_cond_fn,
                                        sample_search_body_fn, state)
    else:
        state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)

    return FlaxSampleOutput(sequences=state.sequences)
def __call__(self,
             inputs_q,
             inputs_kv,
             padding_mask=None,
             key_padding_mask=None,
             segmentation=None,
             key_segmentation=None,
             decode=False):
    """Applies multi-head dot product attention on the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention and projects the results to an output
    vector. This can be used for encoder-decoder attention by specifying both
    `inputs_q` and `inputs_kv` or for self-attention by only specifying
    `inputs_q` and setting `inputs_kv` to None.

    Args:
      inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
      inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]`
        or None for self-attention, in which case key/values will be derived
        from inputs_q.
      padding_mask: boolean specifying query tokens that are pad token.
      key_padding_mask: boolean specifying key-value tokens that are pad token.
      segmentation: segment indices for packed inputs_q data.
      key_segmentation: segment indices for packed inputs_kv data.
      decode: whether to use an autoregressive decoding cache.

    Returns:
      output of shape `[bs, dim1, dim2, ..., dimN, features]`.
    """
    assert self.causal_mask or not self.decode, (
        'Caching is only supported for causal attention.')

    if inputs_kv is None:
        inputs_kv = inputs_q

    attention_axis = self.attention_axis
    if self.attention_axis is None:
        attention_axis = tuple(range(1, inputs_q.ndim - 1))

    features = self.out_features or inputs_q.shape[-1]
    qkv_features = self.qkv_features or inputs_q.shape[-1]

    assert qkv_features % self.num_heads == 0, (
        'Memory dimension must be divisible by number of heads.')
    head_dim = qkv_features // self.num_heads

    dense = partial(DenseGeneral,
                    axis=-1,
                    features=(self.num_heads, head_dim),
                    kernel_init=self.kernel_init,
                    bias_init=self.bias_init,
                    use_bias=self.use_bias,
                    precision=self.precision)
    # project inputs_q to multi-headed q/k/v
    # dimensions are then [bs, dims..., n_heads, n_features_per_head]
    query, key, value = (dense(dtype=self.dtype, name='query')(inputs_q),
                         dense(dtype=self.dtype, name='key')(inputs_kv),
                         dense(dtype=self.dtype, name='value')(inputs_kv))

    if self.decode:
        # detect if we're initializing by absence of existing cache data.
        is_initialized = self.has_variable('cache', 'cached_key')
        cached_key = self.variable('cache', 'cached_key', jnp.zeros,
                                   key.shape, key.dtype)
        cached_value = self.variable('cache', 'cached_value', jnp.zeros,
                                     value.shape, value.dtype)
        cache_index = self.variable('cache', 'cache_index',
                                    lambda: jnp.array(0, dtype=jnp.uint32))
        if is_initialized:
            expected_shape = list(cached_key.value.shape[:-2])
            for attn_dim in attention_axis:
                expected_shape[attn_dim] = 1
            expected_shape = tuple(expected_shape) + inputs_q.shape[-1:]
            if expected_shape != inputs_q.shape:
                raise ValueError('Invalid shape provided, '
                                 'expected shape %s instead got %s.' %
                                 (expected_shape, inputs_q.shape))

            cshape = cached_key.value.shape
            indices = [0] * len(cshape)
            i = cache_index.value
            attn_size = np.prod(np.take(cshape, attention_axis))
            for attn_dim in attention_axis:
                attn_size //= cshape[attn_dim]
                indices[attn_dim] = i // attn_size
                i = i % attn_size

            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            cached_key.value = key
            cached_value.value = value
            cache_index.value = cache_index.value + 1

            # TODO(levskaya): verify this is still needed in translation decoding.
            key_padding_mask = jnp.broadcast_to(
                (jnp.arange(cshape[1]) < cache_index.value), cshape[:2])
            key_padding_mask = key_padding_mask.astype(jnp.float32)[..., None]

    # create attention masks
    mask_components = []
    if self.causal_mask:
        if self.decode and is_initialized:
            bias_pre_shape = (1,) * (key.ndim - 1)
            attn_shape = tuple(np.take(key.shape, attention_axis))
            attn_size = np.prod(attn_shape)
            ii = jnp.arange(attn_size, dtype=jnp.uint32)
            mask = ii < cache_index.value
            mask_components.append(mask.reshape(bias_pre_shape + attn_shape))
        else:
            mask_components.append(_make_causal_mask(key, attention_axis))
    if padding_mask is not None:
        if key_padding_mask is None:
            key_padding_mask = padding_mask
        padding_mask = make_padding_mask(padding_mask_query=padding_mask,
                                         padding_mask_key=key_padding_mask,
                                         query_shape=query.shape,
                                         key_shape=key.shape,
                                         attention_axis=attention_axis)
        mask_components.append(padding_mask)
    if segmentation is not None:
        if key_segmentation is None:
            key_segmentation = segmentation
        segmentation_mask = make_padding_mask(padding_mask_query=segmentation,
                                              padding_mask_key=key_segmentation,
                                              query_shape=query.shape,
                                              key_shape=key.shape,
                                              attention_axis=attention_axis,
                                              segmentation_mask=True)
        mask_components.append(segmentation_mask)

    if mask_components:
        attention_mask = mask_components[0]
        for component in mask_components[1:]:
            attention_mask = jnp.logical_and(attention_mask, component)
        # attention mask in the form of attention bias
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.).astype(self.dtype),
            jnp.full(attention_mask.shape, -1e10).astype(self.dtype))
    else:
        attention_bias = None

    dropout_rng = None
    if not self.deterministic and self.dropout_rate > 0.:
        dropout_rng = self.make_rng('dropout')

    # apply attention
    x = self.attention_fn(query,
                          key,
                          value,
                          dtype=self.dtype,
                          axis=attention_axis,
                          bias=attention_bias,
                          precision=self.precision,
                          dropout_rng=dropout_rng,
                          dropout_rate=self.dropout_rate,
                          broadcast_dropout=self.broadcast_dropout,
                          deterministic=self.deterministic)

    # back to the original inputs dimensions
    out = DenseGeneral(features=features,
                       axis=(-2, -1),
                       kernel_init=self.kernel_init,
                       bias_init=self.bias_init,
                       use_bias=self.use_bias,
                       dtype=self.dtype,
                       precision=self.precision,
                       name='out')(x)

    return out
def beam_search_loop_body_fn(state):
    """Beam search loop state update function."""
    # Collect the current position slice along length to feed the fast
    # autoregressive decoder model. Flatten the beam dimension into batch
    # dimension for feeding into the model.
    # --> [batch * beam, 1]
    flat_ids = flatten_beam_dim(
        lax.dynamic_slice(state.live_seqs, (0, 0, state.cur_index),
                          (batch_size, beam_size, 1)))
    # Flatten beam dimension into batch to be compatible with model.
    # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...}
    flat_cache = jax.tree_map(flatten_beam_dim, state.cache)

    # Call fast-decoder model on current tokens to get next-position logits.
    # --> [batch * beam, vocab]
    flat_logits, new_flat_cache = tokens_to_logits(flat_ids, flat_cache)

    # unflatten beam dimension
    # [batch * beam, vocab] --> [batch, beam, vocab]
    logits = unflatten_beam_dim(flat_logits, batch_size, beam_size)
    # Unflatten beam dimension in attention cache arrays
    # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...}
    new_cache = jax.tree_map(
        lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache)

    # Gather log probabilities from logits
    candidate_log_probs = jax.nn.log_softmax(logits)
    # Add new logprobs to existing prefix logprobs.
    # --> [batch, beam, vocab]
    log_probs = (candidate_log_probs +
                 jnp.expand_dims(state.live_logprobs, axis=2))

    # We'll need the vocab size, gather it from the log probability dimension.
    vocab_size = log_probs.shape[2]

    # Each item in batch has beam_size * vocab_size candidate sequences.
    # For each item, get the top 2*k candidates with the highest log-
    # probabilities. We gather the top 2*K beams here so that even if the best
    # K sequences reach EOS simultaneously, we have another K sequences
    # remaining to continue the live beam search.
    beams_to_keep = 2 * beam_size
    # Flatten beam and vocab dimensions.
    flat_log_probs = log_probs.reshape((batch_size, beam_size * vocab_size))
    # Gather the top 2*K scores from _all_ beams.
    # --> [batch, 2*beams], [batch, 2*beams]
    topk_log_probs, topk_indices = lax.top_k(flat_log_probs, k=beams_to_keep)
    # Recover the beam index by floor division.
    topk_beam_indices = topk_indices // vocab_size
    # Gather 2*k top beams.
    # --> [batch, 2*beams, length]
    topk_seq = gather_beams(state.live_seqs, topk_beam_indices,
                            batch_size, beams_to_keep)

    # Append the most probable 2*K token IDs to the top 2*K sequences.
    # Recover token id by modulo division and expand Id array for broadcasting.
    # --> [batch, 2*beams, 1]
    topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
    # Update sequences for the 2*K top-k new sequences.
    # --> [batch, 2*beams, length]
    topk_seq = lax.dynamic_update_slice(topk_seq, topk_ids,
                                        (0, 0, state.cur_index + 1))

    # Update LIVE (in-progress) sequences:
    # Did any of these sequences reach an end marker?
    # --> [batch, 2*beams]
    newly_finished = (topk_seq[:, :, state.cur_index + 1] == end_marker)
    # To prevent these newly finished sequences from being added to the LIVE
    # set of active beam search sequences, set their log probs to a very large
    # negative value.
    new_log_probs = topk_log_probs + newly_finished * NEG_INF
    # Determine the top k beam indices (from top 2*k beams) from log probs.
    # --> [batch, beams]
    _, new_topk_indices = lax.top_k(new_log_probs, k=beam_size)
    new_topk_indices = jnp.flip(new_topk_indices, axis=1)
    # Gather the top k beams (from top 2*k beams).
    # --> [batch, beams, length], [batch, beams]
    top_alive_seq, top_alive_log_probs = gather_beams(
        [topk_seq, new_log_probs], new_topk_indices, batch_size, beam_size)

    # Determine the top k beam indices from the original set of all beams.
    # --> [batch, beams]
    top_alive_indices = gather_beams(topk_beam_indices, new_topk_indices,
                                     batch_size, beam_size)
    # With these, gather the top k beam-associated caches.
    # --> {[batch, beams, ...], ...}
    top_alive_cache = gather_beams(new_cache, top_alive_indices,
                                   batch_size, beam_size)

    # Update FINISHED (reached end of sentence) sequences:
    # Calculate new seq scores from log probabilities.
    new_scores = topk_log_probs / brevity_penalty(alpha, state.cur_index + 1)
    # Mask out the still unfinished sequences by adding large negative value.
    # --> [batch, 2*beams]
    new_scores += (~newly_finished) * NEG_INF

    # Combine sequences, scores, and flags along the beam dimension and compare
    # new finished sequence scores to existing finished scores and select the
    # best from the new set of beams.
    finished_seqs = jnp.concatenate(  # --> [batch, 3*beams, length]
        [state.finished_seqs, topk_seq], axis=1)
    finished_scores = jnp.concatenate(  # --> [batch, 3*beams]
        [state.finished_scores, new_scores], axis=1)
    finished_flags = jnp.concatenate(  # --> [batch, 3*beams]
        [state.finished_flags, newly_finished], axis=1)
    # --> [batch, beams, length], [batch, beams], [batch, beams]
    top_finished_seq, top_finished_scores, top_finished_flags = (
        gather_topk_beams([finished_seqs, finished_scores, finished_flags],
                          finished_scores, batch_size, beam_size))

    return BeamState(cur_index=state.cur_index + 1,
                     live_logprobs=top_alive_log_probs,
                     finished_scores=top_finished_scores,
                     live_seqs=top_alive_seq,
                     finished_seqs=top_finished_seq,
                     finished_flags=top_finished_flags,
                     cache=top_alive_cache)
def scalar_f2(x): return lax.dynamic_update_slice(x, 7, [])
def update_entry(arr, val, i, j):
    val = lax.reshape(val, [1, 1])
    return lax.dynamic_update_slice(arr, val, (i, j))
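# Usage sketch, assuming the update_entry above and its `lax` import are in
# scope (i and j may also be traced scalars under jit):
import jax.numpy as jnp
from jax import lax

arr = jnp.zeros((3, 3))
print(update_entry(arr, jnp.array(7.), 1, 2))  # sets arr[1, 2] = 7.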
def _beam_search(
    self,
    input_ids: jnp.ndarray,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    length_penalty: Optional[float] = None,
    early_stopping: Optional[bool] = None,
    logits_processor: Optional[FlaxLogitsProcessorList] = None,
    trace: bool = True,
    params: Optional[Dict[str, jnp.ndarray]] = None,
    model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
):
    """
    This beam search function is heavily inspired by Flax's official example:
    https://github.com/google/flax/blob/master/examples/wmt/train.py#L254
    """

    def flatten_beam_dim(tensor):
        """Flattens the first two dimensions of a non-scalar array."""
        # ignore scalars (e.g. cache index)
        if tensor.ndim == 0:
            return tensor
        return tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])

    def unflatten_beam_dim(tensor, batch_size, num_beams):
        """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
        # ignore scalars (e.g. cache index)
        if tensor.ndim == 0:
            return tensor
        return tensor.reshape((batch_size, num_beams) + tensor.shape[1:])

    def gather_beams(nested, beam_indices, batch_size, new_num_beams):
        """Gathers the beam slices indexed by beam_indices into new beam array."""
        batch_indices = jnp.reshape(
            jnp.arange(batch_size * new_num_beams) // new_num_beams,
            (batch_size, new_num_beams))

        def gather_fn(tensor):
            # ignore scalars (e.g. cache index)
            if tensor.ndim == 0:
                return tensor
            else:
                return tensor[batch_indices, beam_indices]

        return jax.tree_map(gather_fn, nested)

    # init values
    max_length = max_length if max_length is not None else self.config.max_length
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
    length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
    early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping

    batch_size, num_beams, cur_len = input_ids.shape

    eos_token_id = jnp.array(eos_token_id)
    pad_token_id = jnp.array(pad_token_id)
    cur_len = jnp.array(cur_len)

    # per batch,beam-item holding current token in loop.
    sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
    running_sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
    running_sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0, 0))

    # per batch,beam-item state bit indicating if sentence has finished.
    is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_)

    # per batch,beam-item score, logprobs
    running_scores = jnp.tile(
        jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)), [batch_size, 1])
    scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7)

    # For Seq2Seq generation, we only need to use the decoder instead of the
    # whole model in the generation loop and pass it the `encoder_outputs`,
    # which are part of the `model_kwargs`.
    model = self.decode if self.config.is_encoder_decoder else self

    # flatten beam dim
    if "encoder_outputs" in model_kwargs:
        model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim(
            model_kwargs["encoder_outputs"]["last_hidden_state"])
    if "attention_mask" in model_kwargs:
        model_kwargs["attention_mask"] = flatten_beam_dim(model_kwargs["attention_mask"])

    # initialize model specific kwargs
    model_kwargs = self.prepare_inputs_for_generation(
        flatten_beam_dim(input_ids), max_length, **model_kwargs)

    # initialize state
    state = BeamSearchState(
        cur_len=cur_len,
        running_sequences=running_sequences,
        running_scores=running_scores,
        sequences=sequences,
        scores=scores,
        is_sent_finished=is_sent_finished,
        model_kwargs=model_kwargs,
    )

    def beam_search_cond_fn(state):
        """beam search state termination condition fn."""
        # 1. is less than max length?
        not_max_length_yet = state.cur_len < max_length

        # 2. can the new beams still improve?
        best_running_score = state.running_scores[:, -1:] / (max_length**length_penalty)
        worst_finished_score = jnp.where(
            state.is_sent_finished,
            jnp.min(state.scores, axis=1, keepdims=True),
            np.array(-1.0e7))
        improvement_still_possible = jnp.all(worst_finished_score < best_running_score)

        # 3. is there still a beam that has not finished?
        still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping)

        return not_max_length_yet & still_open_beam & improvement_still_possible

    def beam_search_body_fn(state, input_ids_length=1):
        """beam search state update fn."""
        # 1. Forward current tokens
        # Collect the current position slice along length to feed the fast
        # autoregressive decoder model. Flatten the beam dimension into batch
        # dimension for feeding into the model.
        input_token = flatten_beam_dim(
            lax.dynamic_slice(
                state.running_sequences,
                (0, 0, state.cur_len - input_ids_length),
                (batch_size, num_beams, input_ids_length),
            ))
        model_outputs = model(input_token, params=params, **state.model_kwargs)

        # unflatten beam dimension
        logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size, num_beams)
        # Unflatten beam dimension in attention cache arrays
        cache = jax.tree_map(
            lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams),
            model_outputs.past_key_values)

        # adapt logits for FlaxMarianMTModel
        logits = self._adapt_logits_for_beam_search(logits)

        # 2. Compute log probs
        # get log probabilities from logits, process logits with processors
        # (*e.g.* min_length, ...), and add new logprobs to existing running
        # logprobs scores.
        log_probs = jax.nn.log_softmax(logits)
        log_probs = logits_processor(
            flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), state.cur_len)
        log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
        log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2)
        vocab_size = log_probs.shape[2]
        log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))

        # 3. Retrieve top-K
        # Each item in batch has num_beams * vocab_size candidate sequences.
        # For each item, get the top 2*k candidates with the highest log-
        # probabilities. We gather the top 2*K beams here so that even if the
        # best K sequences reach EOS simultaneously, we have another K
        # sequences remaining to continue the live beam search.
        # Gather the top 2*K scores from _all_ beams.
        # Gather 2*k top beams.
        # Recover the beam index by floor division.
        # Recover token id by modulo division and expand Id array for broadcasting.
        # Update sequences for the 2*K top-k new sequences.
        beams_to_keep = 2 * num_beams
        topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep)
        topk_beam_indices = topk_indices // vocab_size
        topk_running_sequences = gather_beams(
            state.running_sequences, topk_beam_indices, batch_size, beams_to_keep)
        topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
        topk_sequences = lax.dynamic_update_slice(
            topk_running_sequences, topk_ids, (0, 0, state.cur_len))

        # 4. Check which sequences have ended
        # Update current sequences: did any of these sequences reach an end
        # marker? To prevent these just finished sequences from being added to
        # the current set of active beam search sequences, set their log probs
        # to a very large negative value.
        did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id
        running_topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)

        # 5. Get running sequences scores for next
        # Determine the top k beam indices (from top 2*k beams) from log probs
        # and gather top k beams (from top 2*k beams).
        next_topk_indices = jnp.flip(lax.top_k(running_topk_log_probs, k=num_beams)[1], axis=1)
        next_running_sequences, next_running_scores = gather_beams(
            [topk_sequences, running_topk_log_probs], next_topk_indices,
            batch_size, num_beams)

        # 6. Process topk logits
        # Further process log probs:
        # - add length penalty
        # - make sure no scores can be added anymore if beam is full
        # - make sure still running sequences cannot be chosen as finalized beam
        topk_log_probs = topk_log_probs / (state.cur_len**length_penalty)
        beams_in_batch_are_full = (jnp.broadcast_to(
            state.is_sent_finished.all(axis=-1, keepdims=True),
            did_topk_just_finished.shape) & early_stopping)
        add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
        topk_log_probs += add_penalty * np.array(-1.0e7)

        # 7. Get scores, sequences, is sentence finished for next.
        # Combine sequences, scores, and flags along the beam dimension and
        # compare new finished sequence scores to existing finished scores and
        # select the best from the new set of beams.
        merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1)
        merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
        merged_is_sent_finished = jnp.concatenate(
            [state.is_sent_finished, did_topk_just_finished], axis=1)
        topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1)
        next_sequences, next_scores, next_is_sent_finished = gather_beams(
            [merged_sequences, merged_scores, merged_is_sent_finished],
            topk_merged_indices, batch_size, num_beams)

        # 8. Update model kwargs.
        # Determine the top k beam indices from the original set of all beams.
        # With these, gather the top k beam-associated caches.
        next_running_indices = gather_beams(
            topk_beam_indices, next_topk_indices, batch_size, num_beams)
        next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams)
        model_outputs["past_key_values"] = jax.tree_map(
            lambda x: flatten_beam_dim(x), next_cache)
        next_model_kwargs = self.update_inputs_for_generation(model_outputs,
                                                              state.model_kwargs)

        return BeamSearchState(
            cur_len=state.cur_len + 1,
            running_scores=next_running_scores,
            running_sequences=next_running_sequences,
            scores=next_scores,
            sequences=next_sequences,
            is_sent_finished=next_is_sent_finished,
            model_kwargs=next_model_kwargs,
        )

    # The very first prompt often has sequence length > 1, so run outside of
    # `lax.while_loop` to comply with TPU
    if input_ids.shape[-1] > 1:
        state = partial(beam_search_body_fn, input_ids_length=input_ids.shape[-1])(state)

    if not trace:
        state = self._run_loop_in_debug(beam_search_cond_fn, beam_search_body_fn, state)
    else:
        state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn, state)

    # Account for the edge-case where there are no finished sequences for a
    # particular batch item. If so, return running sequences for that batch item.
    none_finished = jnp.any(state.is_sent_finished, axis=1)
    sequences = jnp.where(none_finished[:, None, None], state.sequences,
                          state.running_sequences)
    scores = jnp.where(none_finished[:, None], state.scores, state.running_scores)

    # take best beam for each batch
    sequences = sequences[:, -1]
    scores = scores[:, -1]

    return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
def apply(self,
          inputs_q,
          inputs_kv,
          num_heads,
          dtype=jnp.float32,
          qkv_features=None,
          out_features=None,
          attention_axis=None,
          causal_mask=False,
          padding_mask=None,
          key_padding_mask=None,
          segmentation=None,
          key_segmentation=None,
          cache=None,
          broadcast_dropout=True,
          dropout_rng=None,
          dropout_rate=0.,
          deterministic=False,
          precision=None,
          kernel_init=nn.linear.default_kernel_init,
          bias_init=nn.initializers.zeros,
          bias=True,
          block_size=50,
          max_num_blocks=25,
          sort_activation='softmax'):
    """Applies multi-head sinkhorn attention on the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention and projects the results to an output
    vector. This can be used for encoder-decoder attention by specifying both
    `inputs_q` and `inputs_kv` or for self-attention by only specifying
    `inputs_q` and setting `inputs_kv` to None.

    Args:
      inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
      inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]`
        or None for self-attention, in which case key/values will be derived
        from inputs_q.
      num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
        should be divisible by the number of heads.
      dtype: the dtype of the computation (default: float32)
      qkv_features: dimension of the key, query, and value.
      out_features: dimension of the last projection
      attention_axis: axes over which the attention is applied ('None' means
        attention over all axes, but batch, heads, and features).
      causal_mask: boolean specifying whether to apply a causal mask on the
        attention weights. If True, the output at timestep `t` will not depend
        on inputs at timesteps strictly greater than `t`.
      padding_mask: boolean specifying query tokens that are pad token.
      key_padding_mask: boolean specifying key-value tokens that are pad token.
      segmentation: segment indices for packed inputs_q data.
      key_segmentation: segment indices for packed inputs_kv data.
      cache: an instance of `flax.nn.attention.Cache` used for efficient
        autoregressive decoding.
      broadcast_dropout: bool: use a broadcasted dropout along batch dims.
      dropout_rng: JAX PRNGKey: to be used for dropout
      dropout_rate: dropout rate
      deterministic: bool, deterministic or not (to apply dropout)
      precision: numerical precision of the computation see `jax.lax.Precision`
        for details.
      kernel_init: initializer for the kernel of the Dense layers.
      bias_init: initializer for the bias of the Dense layers.
      bias: bool: whether pointwise QKVO dense transforms use bias.
      block_size: int, block size.
      max_num_blocks: int, max num blocks.
      sort_activation: str {softmax, sinkhorn, gumbel_sinkhorn}

    Returns:
      output of shape `[bs, dim1, dim2, ..., dimN, features]`.
    """
    assert causal_mask or not cache, (
        'Caching is only supported for causal attention.')

    assert inputs_q.ndim == 3

    if inputs_kv is None:
        inputs_kv = inputs_q

    if attention_axis is None:
        attention_axis = tuple(range(1, inputs_q.ndim - 1))

    features = out_features or inputs_q.shape[-1]
    qkv_features = qkv_features or inputs_q.shape[-1]

    assert qkv_features % num_heads == 0, (
        'Memory dimension must be divisible by number of heads.')
    head_dim = qkv_features // num_heads

    dense = nn.DenseGeneral.partial(
        axis=-1,
        features=(num_heads, head_dim),
        kernel_init=kernel_init,
        bias_init=bias_init,
        bias=bias,
        precision=precision)
    # project inputs_q to multi-headed q/k/v
    # dimensions are then [bs, dims..., n_heads, n_features_per_head]
    qlength = inputs_q.shape[-2]
    bs = inputs_q.shape[0]
    kvlength = inputs_kv.shape[-2]
    query, key, value = (dense(inputs_q, dtype=dtype, name='query'),
                         dense(inputs_kv, dtype=dtype, name='key'),
                         dense(inputs_kv, dtype=dtype, name='value'))

    if cache:
        assert isinstance(cache, Cache), 'cache must be an instance of Cache'
        if self.is_initializing():
            cache.store(onp.array((key.ndim,) + key.shape[-2:], dtype=onp.int32))
        else:
            cache_entry = cache.retrieve(None)
            expected_shape = list(cache_entry.key.shape[:-2])
            for attn_dim in attention_axis:
                expected_shape[attn_dim] = 1
            expected_shape = tuple(expected_shape) + inputs_q.shape[-1:]
            if expected_shape != inputs_q.shape:
                raise ValueError('Invalid shape provided, '
                                 'expected shape %s instead got %s.' %
                                 (expected_shape, inputs_q.shape))
            if not isinstance(cache_entry, _CacheEntry):
                raise ValueError('Cache is not initialized.')

            cshape = cache_entry.key.shape
            indices = [0] * len(cshape)
            i = cache_entry.i
            attn_size = onp.prod(onp.take(cshape, attention_axis))
            for attn_dim in attention_axis:
                attn_size //= cshape[attn_dim]
                indices[attn_dim] = i // attn_size
                i = i % attn_size

            key = lax.dynamic_update_slice(cache_entry.key, key, indices)
            value = lax.dynamic_update_slice(cache_entry.value, value, indices)
            one = jnp.array(1, jnp.uint32)
            cache_entry = cache_entry.replace(i=cache_entry.i + one,
                                              key=key,
                                              value=value)
            cache.store(cache_entry)

            key_padding_mask = jnp.broadcast_to(
                (jnp.arange(cshape[1]) < cache_entry.i), cshape[:2])
            key_padding_mask = key_padding_mask.astype(jnp.float32)[..., None]

    # block reshape before attention
    num_query_blocks = qlength // block_size
    num_kv_blocks = kvlength // block_size

    block_query = jnp.reshape(query,
                              (bs, block_size, num_query_blocks, num_heads, head_dim))
    block_key = jnp.reshape(key,
                            (bs, block_size, num_kv_blocks, num_heads, head_dim))
    block_value = jnp.reshape(value,
                              (bs, block_size, num_kv_blocks, num_heads, head_dim))

    if causal_mask:
        # causal masking needs to not have blocks with mixed information.
        sum_key = jnp.cumsum(block_key, axis=1)
        sum_key = sum_key[:, 0, :, :, :]  # take first item
    else:
        sum_key = jnp.sum(block_key, axis=1)

    # sort net on head_dim dimensions
    sort_out = nn.DenseGeneral(sum_key,
                               axis=-1,
                               features=(max_num_blocks),
                               kernel_init=kernel_init,
                               bias_init=bias_init,
                               bias=bias,
                               precision=precision)
    # (bs, num_kv_blocks, num_heads, num_query_blocks)
    sort_out = sort_out[:, :, :, :num_query_blocks]

    # simple softmax sorting first.
    if sort_activation == 'sinkhorn':
        permutation = sinkhorn_operator(
            jnp.reshape(sort_out, (-1, num_kv_blocks, num_query_blocks)),
            causal=causal_mask)
        permutation = jnp.reshape(permutation,
                                  (-1, num_kv_blocks, num_heads, num_query_blocks))
    else:
        if causal_mask:
            block_mask = _make_causal_mask(key, attention_axis)
            sort_out += block_mask
        permutation = jax.nn.softmax(sort_out, axis=-1)

    sorted_key = jnp.einsum('bskhd,bnhl->bsnhd', block_key, permutation)
    sorted_value = jnp.einsum('bskhd,bnhl->bsnhd', block_value, permutation)

    # create attention masks
    mask_components = []
    sorted_mask_components = []

    if causal_mask:
        # TODO(yitay): Test this causal masking.
        if cache and not self.is_initializing():
            bias_pre_shape = (1,) * (key.ndim - 1)
            attn_shape = tuple(onp.take(key.shape, attention_axis))
            attn_size = onp.prod(attn_shape)
            ii = jnp.arange(attn_size, dtype=jnp.uint32)
            mask = ii < cache_entry.i
            mask_components.append(mask.reshape(bias_pre_shape + attn_shape))
        else:
            mask_components.append(_make_causal_mask(key, attention_axis))

    if padding_mask is not None:
        # divide padding mask into blocks
        padding_mask = jnp.reshape(padding_mask,
                                   (bs * num_query_blocks, block_size, 1))
        if key_padding_mask is None:
            key_padding_mask = padding_mask
        padding_mask = make_padding_mask(
            padding_mask_query=padding_mask,
            padding_mask_key=key_padding_mask,
            query_shape=(bs * num_query_blocks, block_size, num_heads, head_dim),
            key_shape=(bs * num_kv_blocks, block_size, num_heads, head_dim),
            attention_axis=attention_axis)
        padding_mask = jnp.reshape(padding_mask,
                                   (bs, num_query_blocks, block_size, block_size))
        mask_components.append(padding_mask)
        sorted_padding_mask = jnp.einsum('bksj,bnhl->bnsj', padding_mask, permutation)
        sorted_mask_components.append(sorted_padding_mask)

    if segmentation is not None:
        if key_segmentation is None:
            key_segmentation = segmentation
        segmentation_mask = make_padding_mask(
            padding_mask_query=segmentation,
            padding_mask_key=key_segmentation,
            query_shape=(bs * num_query_blocks, block_size, num_heads, head_dim),
            key_shape=(bs * num_kv_blocks, block_size, num_heads, head_dim),
            attention_axis=attention_axis,
            segmentation_mask=True)
        segmentation_mask = jnp.reshape(segmentation_mask,
                                        (bs, num_query_blocks, block_size, block_size))
        mask_components.append(segmentation_mask)
        sorted_segmentation_mask = jnp.einsum('bksj,bnhl->bnsj',
                                              segmentation_mask, permutation)
        sorted_mask_components.append(sorted_segmentation_mask)

    if mask_components:
        attention_mask = mask_components[0]
        for component in mask_components[1:]:
            attention_mask = jnp.logical_and(attention_mask, component)
        # attention mask in the form of attention bias
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.).astype(dtype),
            jnp.full(attention_mask.shape, -1e10).astype(dtype))
    else:
        attention_bias = None

    if sorted_mask_components:
        attention_mask = sorted_mask_components[0]
        for component in sorted_mask_components[1:]:
            attention_mask = jnp.logical_and(attention_mask, component)
        # attention mask in the form of attention bias
        sorted_attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.).astype(dtype),
            jnp.full(attention_mask.shape, -1e10).astype(dtype))
    else:
        sorted_attention_bias = None

    # apply attention
    x = local_dot_product_attention(block_query,
                                    block_key,
                                    block_value,
                                    dtype=dtype,
                                    axis=attention_axis,
                                    bias=attention_bias,
                                    precision=precision,
                                    dropout_rng=dropout_rng,
                                    dropout_rate=dropout_rate,
                                    broadcast_dropout=broadcast_dropout,
                                    deterministic=deterministic)
    sorted_x = local_dot_product_attention(block_query,
                                           sorted_key,
                                           sorted_value,
                                           dtype=dtype,
                                           axis=attention_axis,
                                           bias=sorted_attention_bias,
                                           precision=precision,
                                           dropout_rng=dropout_rng,
                                           dropout_rate=dropout_rate,
                                           broadcast_dropout=broadcast_dropout,
                                           deterministic=deterministic)
    x = x + sorted_x
    x = jnp.reshape(x, (bs, qlength, num_heads, head_dim))

    # back to the original inputs dimensions
    out = nn.DenseGeneral(x,
                          features=features,
                          axis=(-2, -1),
                          kernel_init=kernel_init,
                          bias_init=bias_init,
                          bias=bias,
                          dtype=dtype,
                          precision=precision,
                          name='out')
    return out
_make_harness("dot_general", "", lambda lhs, rhs: lax.dot_general(lhs, rhs, dimension_numbers=(((2,), (1,)), ((0,), (0,)))), [RandArg((3, 4, 4), _f32), RandArg((3, 4), _f32)], poly_axes=[0, 0]), _make_harness("dynamic_slice", "", # x:shape: (b, 4) lambda x: lax.dynamic_slice(x, (0, 1), (x.shape[0], 2)), [RandArg((3, 4), _f32)], poly_axes=[0]), _make_harness("dynamic_update_slice", "", # x:shape: (b, 4) lambda x: lax.dynamic_update_slice(x, x, (0, 0)), [RandArg((3, 4), _f32)], poly_axes=[0]), _make_harness("einsum", "0", lambda x: jnp.einsum("...i->...", x), [RandArg((3, 4), _f32)], poly_axes=[0]), _make_harness("einsum", "1", lambda x, y: jnp.einsum("...ij,...jk->...ik", x, y), [RandArg((3, 4, 5), _f32), RandArg((3, 5, 6), _f32)], poly_axes=[0, 0]), _make_harness("einsum", "2", lambda x, y: jnp.einsum("...ij,jk->...ik", x, y),
def __call__(self,
             inputs_q: Array,
             inputs_kv: Array,
             mask: Optional[Array] = None):
    """Applies multi-head dot product attention on the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention and project the results to an output vector.

    Args:
      inputs_q: input queries of shape `[batch_sizes..., length, features]`.
      inputs_kv: key/values of shape `[batch_sizes..., length, features]`.
      mask: attention mask of shape
        `[batch_sizes..., num_heads, query_length, key/value_length]`.

    Returns:
      output of shape `[batch_sizes..., length, features]`.
    """
    features = self.out_features or inputs_q.shape[-1]
    qkv_features = self.qkv_features or inputs_q.shape[-1]
    assert qkv_features % self.num_heads == 0, (
        'Memory dimension must be divisible by number of heads.')
    head_dim = qkv_features // self.num_heads

    dense = partial(DenseGeneral,
                    axis=-1,
                    features=(self.num_heads, head_dim),
                    kernel_init=self.kernel_init,
                    bias_init=self.bias_init,
                    use_bias=self.use_bias,
                    precision=self.precision)
    # project inputs_q to multi-headed q/k/v
    # dimensions are then [batch..., length, n_heads, n_features_per_head]
    query, key, value = (dense(dtype=self.dtype, name='query')(inputs_q),
                         dense(dtype=self.dtype, name='key')(inputs_kv),
                         dense(dtype=self.dtype, name='value')(inputs_kv))

    # During fast autoregressive decoding, we feed one position at a time,
    # and cache the keys and values step by step.
    if self.decode:
        # detect if we're initializing by absence of existing cache data.
        is_initialized = self.has_variable('cache', 'cached_key')
        cached_key = self.variable('cache', 'cached_key', jnp.zeros,
                                   key.shape, key.dtype)
        cached_value = self.variable('cache', 'cached_value', jnp.zeros,
                                     value.shape, value.dtype)
        cache_index = self.variable('cache', 'cache_index',
                                    lambda: jnp.array(0, dtype=jnp.int32))
        if is_initialized:
            *batch_dims, max_length, num_heads, depth_per_head = (
                cached_key.value.shape)
            # shape check of cached keys against query input
            expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head)
            if expected_shape != query.shape:
                raise ValueError('Autoregressive cache shape error, '
                                 'expected query shape %s instead got %s.' %
                                 (expected_shape, query.shape))
            # update key, value caches with our new 1d spatial slices
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            cached_key.value = key
            cached_value.value = value
            cache_index.value = cache_index.value + 1
            # causal mask for cached decoder self-attention:
            # our single query position should only attend to those key
            # positions that have already been generated and cached,
            # not the remaining zero elements.
            mask = combine_masks(
                mask,
                jnp.broadcast_to(
                    jnp.arange(max_length) <= cur_index,
                    tuple(batch_dims) + (1, 1, max_length)))

    # Convert the boolean attention mask to an attention bias.
    if mask is not None:
        # attention mask in the form of attention bias
        attention_bias = lax.select(
            mask > 0,
            jnp.full(mask.shape, 0.).astype(self.dtype),
            jnp.full(mask.shape, -1e10).astype(self.dtype))
    else:
        attention_bias = None

    dropout_rng = None
    if not self.deterministic and self.dropout_rate > 0.:
        dropout_rng = self.make_rng('dropout')

    # apply attention
    x = self.attention_fn(query,
                          key,
                          value,
                          bias=attention_bias,
                          dropout_rng=dropout_rng,
                          dropout_rate=self.dropout_rate,
                          broadcast_dropout=self.broadcast_dropout,
                          deterministic=self.deterministic,
                          dtype=self.dtype,
                          precision=self.precision)

    # back to the original inputs dimensions
    out = DenseGeneral(features=features,
                       axis=(-2, -1),
                       kernel_init=self.kernel_init,
                       bias_init=self.bias_init,
                       use_bias=self.use_bias,
                       dtype=self.dtype,
                       precision=self.precision,
                       name='out')(x)
    return out
def _add(cache):
    # return cache.at[:, -1, index_w_in, :].set(inputs)
    return lax.dynamic_update_slice(cache, inputs, (0, -1, index_w_in, 0))