def beam_search_loop_body_fn(state):
    """Advance the beam search by one decoding step.

    Feeds the token at ``state.cur_index`` of every live beam to each model in
    ``models``, combines the per-model log-probabilities, extends the live
    sequences by one token, and folds any sequences that just emitted
    ``end_marker`` into the finished set.

    Names such as ``batch_size``, ``beam_size``, ``models``, ``ens_type``,
    ``tokens_to_logits``, ``end_marker``, ``alpha`` and the beam helper
    functions are not defined here; they are expected to come from the
    enclosing scope.

    Ensembling (only when more than one model is given):
      * ``'arithmetic'``: the mean of the per-model log-probabilities.
      * ``'geometric'``: the elementwise product of the per-model
        log-probabilities, reduced to ``-(|product| ** (1 / n_models))``.

    Args:
      state: the current ``BeamState``. This function reads ``cur_index``,
        ``live_seqs``, ``live_logprobs``, ``finished_seqs``,
        ``finished_scores``, ``finished_flags`` and ``cache`` (one cache
        pytree per model).

    Returns:
      A new ``BeamState`` with ``cur_index`` incremented by one and updated
      live / finished sequences, scores, flags and caches.

    Raises:
      ValueError: if several models are used and ``ens_type`` is neither
        ``'arithmetic'`` nor ``'geometric'``.
    """
    # Collect the current position slice along length to feed the fast
    # autoregressive decoder model. Flatten the beam dimension into batch
    # dimension for feeding into the model.
    # --> [batch * beam, 1]
    flat_ids = flatten_beam_dim(
        lax.dynamic_slice(state.live_seqs, (0, 0, state.cur_index),
                          (batch_size, beam_size, 1)))

    new_caches = []
    log_probs_list = []
    for i, model in enumerate(models):
        # Flatten beam dimension into batch to be compatible with model.
        # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...}
        # `jax.tree_util.tree_map` is used instead of the `jax.tree_map`
        # alias, which is deprecated and absent from recent JAX releases.
        flat_cache = jax.tree_util.tree_map(flatten_beam_dim, state.cache[i])

        # Call fast-decoder model on current tokens to get next-position
        # logits.
        # --> [batch * beam, vocab]
        flat_logits, new_flat_cache = tokens_to_logits(
            model, flat_ids, flat_cache, i)

        # Unflatten beam dimension.
        # [batch * beam, vocab] --> [batch, beam, vocab]
        current_logits = unflatten_beam_dim(flat_logits, batch_size, beam_size)

        # Gather log probabilities from logits.
        log_probs_list.append(jax.nn.log_softmax(current_logits))

        # Unflatten beam dimension in attention cache arrays.
        # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...}
        new_caches.append(
            jax.tree_util.tree_map(
                lambda x: unflatten_beam_dim(x, batch_size, beam_size),
                new_flat_cache))

    if len(models) > 1:
        if ens_type == 'arithmetic':
            # Mean of the per-model log-probabilities.
            candidate_log_probs = functools.reduce(
                lambda x, y: x + y, log_probs_list)
            candidate_log_probs /= len(models)
        elif ens_type == 'geometric':
            # n-th root of the magnitude of the product of the per-model
            # log-probabilities, with the sign forced back to negative.
            candidate_log_probs = functools.reduce(
                lambda x, y: x * y, log_probs_list)
            candidate_log_probs = -1 * (
                jnp.abs(candidate_log_probs)**(1 / len(models)))
        else:
            raise ValueError('Unrecognized ens_type: %s' % ens_type)
    else:
        candidate_log_probs = log_probs_list[0]

    # Add new logprobs to existing prefix logprobs.
    # --> [batch, beam, vocab]
    log_probs = (
        candidate_log_probs + jnp.expand_dims(state.live_logprobs, axis=2))

    # We'll need the vocab size, gather it from the log probability dimension.
    vocab_size = log_probs.shape[2]

    # Each item in batch has beam_size * vocab_size candidate sequences.
    # For each item, get the top 2*k candidates with the highest log-
    # probabilities. We gather the top 2*K beams here so that even if the best
    # K sequences reach EOS simultaneously, we have another K sequences
    # remaining to continue the live beam search.
    beams_to_keep = 2 * beam_size
    # Flatten beam and vocab dimensions.
    flat_log_probs = log_probs.reshape((batch_size, beam_size * vocab_size))
    # Gather the top 2*K scores from _all_ beams.
    # --> [batch, 2*beams], [batch, 2*beams]
    topk_log_probs, topk_indices = lax.top_k(flat_log_probs, k=beams_to_keep)
    # Recover the beam index by floor division.
    topk_beam_indices = topk_indices // vocab_size
    # Gather 2*k top beams.
    # --> [batch, 2*beams, length]
    topk_seq = gather_beams(state.live_seqs, topk_beam_indices, batch_size,
                            beams_to_keep)

    # Append the most probable 2*K token IDs to the top 2*K sequences.
    # Recover token id by modulo division and expand Id array for
    # broadcasting.
    # --> [batch, 2*beams, 1]
    topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
    # Update sequences for the 2*K top-k new sequences.
    # --> [batch, 2*beams, length]
    topk_seq = lax.dynamic_update_slice(topk_seq, topk_ids,
                                        (0, 0, state.cur_index + 1))

    # Update LIVE (in-progress) sequences:
    # Did any of these sequences reach an end marker?
    # --> [batch, 2*beams]
    newly_finished = (topk_seq[:, :, state.cur_index + 1] == end_marker)
    # To prevent these newly finished sequences from being added to the LIVE
    # set of active beam search sequences, set their log probs to a very large
    # negative value.
    new_log_probs = topk_log_probs + newly_finished * _NEG_INF
    # Determine the top k beam indices (from top 2*k beams) from log probs.
    # --> [batch, beams]
    _, new_topk_indices = lax.top_k(new_log_probs, k=beam_size)
    # Reverse the order in which top_k returned the indices.
    new_topk_indices = jnp.flip(new_topk_indices, axis=1)
    # Gather the top k beams (from top 2*k beams).
    # --> [batch, beams, length], [batch, beams]
    top_alive_seq, top_alive_log_probs = gather_beams(
        [topk_seq, new_log_probs], new_topk_indices, batch_size, beam_size)

    # Determine the top k beam indices from the original set of all beams.
    # --> [batch, beams]
    top_alive_indices = gather_beams(topk_beam_indices, new_topk_indices,
                                     batch_size, beam_size)
    # With these, gather the top k beam-associated caches, one per model.
    # --> {[batch, beams, ...], ...}
    top_alive_caches = [
        gather_beams(model_cache, top_alive_indices, batch_size, beam_size)
        for model_cache in new_caches
    ]

    # Update FINISHED (reached end of sentence) sequences:
    # Calculate new seq scores from log probabilities.
    new_scores = topk_log_probs / brevity_penalty(alpha, state.cur_index + 1)
    # Mask out the still unfinished sequences by adding large negative value.
    # --> [batch, 2*beams]
    new_scores += (~newly_finished) * _NEG_INF

    # Combine sequences, scores, and flags along the beam dimension and compare
    # new finished sequence scores to existing finished scores and select the
    # best from the new set of beams.
    finished_seqs = jnp.concatenate(  # --> [batch, 3*beams, length]
        [state.finished_seqs, topk_seq], axis=1)
    finished_scores = jnp.concatenate(  # --> [batch, 3*beams]
        [state.finished_scores, new_scores], axis=1)
    finished_flags = jnp.concatenate(  # --> [batch, 3*beams]
        [state.finished_flags, newly_finished], axis=1)
    # --> [batch, beams, length], [batch, beams], [batch, beams]
    top_finished_seq, top_finished_scores, top_finished_flags = (
        gather_topk_beams([finished_seqs, finished_scores, finished_flags],
                          finished_scores, batch_size, beam_size))

    return BeamState(
        cur_index=state.cur_index + 1,
        live_logprobs=top_alive_log_probs,
        finished_scores=top_finished_scores,
        live_seqs=top_alive_seq,
        finished_seqs=top_finished_seq,
        finished_flags=top_finished_flags,
        cache=top_alive_caches)
def beam_search_body_fn(state, input_ids_length=1):
    """Advance the beam search by one decoding step.

    Runs the model on the last ``input_ids_length`` tokens of every running
    beam, scores all ``num_beams * vocab_size`` continuations, keeps the best
    ``num_beams`` still-running beams and merges any beams that just produced
    ``eos_token_id`` into the set of finished sequences.

    Names such as ``batch_size``, ``num_beams``, ``model``, ``params``,
    ``logits_processor``, ``eos_token_id``, ``length_penalty``,
    ``early_stopping``, ``self`` and the beam helper functions are not defined
    here; they are expected to come from the enclosing scope.

    Args:
      state: the current ``BeamSearchState``. This function reads ``cur_len``,
        ``running_sequences``, ``running_scores``, ``sequences``, ``scores``,
        ``is_sent_finished`` and ``model_kwargs``.
      input_ids_length: how many tokens, ending just before position
        ``state.cur_len``, are sliced out and fed to the model.

    Returns:
      A new ``BeamSearchState`` with ``cur_len`` incremented by one.
    """
    # 1. Forward current tokens
    # Collect the current position slice along length to feed the fast
    # autoregressive decoder model. Flatten the beam dimension into batch
    # dimension for feeding into the model.
    input_token = flatten_beam_dim(
        lax.dynamic_slice(
            state.running_sequences,
            (0, 0, state.cur_len - input_ids_length),
            (batch_size, num_beams, input_ids_length),
        ))
    model_outputs = model(input_token, params=params, **state.model_kwargs)

    # Unflatten beam dimension; only the logits of the last position are used.
    logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size,
                                num_beams)
    # Unflatten beam dimension in attention cache arrays.
    # `jax.tree_util.tree_map` is used instead of the `jax.tree_map` alias,
    # which is deprecated and absent from recent JAX releases.
    cache = jax.tree_util.tree_map(
        lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams),
        model_outputs.past_key_values)

    # adapt logits for FlaxMarianMTModel
    logits = self._adapt_logits_for_beam_search(logits)

    # 2. Compute log probs
    # get log probabilities from logits,
    # process logits with processors (*e.g.* min_length, ...), and
    # add new logprobs to existing running logprobs scores.
    log_probs = jax.nn.log_softmax(logits)
    # The processors must see the sequences carried in `state` (the tokens
    # generated so far), not a name captured from outside this step function.
    log_probs = logits_processor(
        flatten_beam_dim(state.running_sequences),
        flatten_beam_dim(log_probs),
        state.cur_len)
    log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
    log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2)
    vocab_size = log_probs.shape[2]
    log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))

    # 3. Retrieve top-K
    # Each item in batch has num_beams * vocab_size candidate sequences.
    # For each item, get the top 2*k candidates with the highest log-
    # probabilities. We gather the top 2*K beams here so that even if the best
    # K sequences reach EOS simultaneously, we have another K sequences
    # remaining to continue the live beam search.
    beams_to_keep = 2 * num_beams
    # Gather the top 2*K scores from _all_ beams.
    topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep)
    # Recover the beam index by floor division.
    topk_beam_indices = topk_indices // vocab_size
    # Gather 2*k top beams.
    topk_running_sequences = gather_beams(state.running_sequences,
                                          topk_beam_indices, batch_size,
                                          beams_to_keep)
    # Recover token id by modulo division and expand Id array for
    # broadcasting.
    topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
    # Update sequences for the 2*K top-k new sequences.
    topk_sequences = lax.dynamic_update_slice(topk_running_sequences, topk_ids,
                                              (0, 0, state.cur_len))

    # 4. Check which sequences have ended
    # Update current sequences:
    # Did any of these sequences reach an end marker?
    # To prevent these just finished sequences from being added to the current
    # sequences set of active beam search sequences, set their log probs to a
    # very large negative value.
    did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id
    running_topk_log_probs = (
        topk_log_probs + did_topk_just_finished * np.array(-1.0e7))

    # 5. Get running sequences scores for next
    # Determine the top k beam indices (from top 2*k beams) from log probs
    # and gather top k beams (from top 2*k beams).
    next_topk_indices = jnp.flip(
        lax.top_k(running_topk_log_probs, k=num_beams)[1], axis=1)
    next_running_sequences, next_running_scores = gather_beams(
        [topk_sequences, running_topk_log_probs], next_topk_indices,
        batch_size, num_beams)

    # 6. Process topk logits
    # Further process log probs:
    # - add length penalty
    # - make sure no scores can be added anymore if beam is full
    # - make sure still running sequences cannot be chosen as finalized beam
    topk_log_probs = topk_log_probs / (state.cur_len**length_penalty)
    beams_in_batch_are_full = (
        jnp.broadcast_to(
            state.is_sent_finished.all(axis=-1, keepdims=True),
            did_topk_just_finished.shape)
        & early_stopping)
    add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
    topk_log_probs += add_penalty * np.array(-1.0e7)

    # 7. Get scores, sequences, is sentence finished for next.
    # Combine sequences, scores, and flags along the beam dimension and compare
    # new finished sequence scores to existing finished scores and select the
    # best from the new set of beams
    merged_sequences = jnp.concatenate(
        [state.sequences, topk_sequences], axis=1)
    merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
    merged_is_sent_finished = jnp.concatenate(
        [state.is_sent_finished, did_topk_just_finished], axis=1)
    topk_merged_indices = jnp.flip(
        lax.top_k(merged_scores, k=num_beams)[1], axis=1)
    next_sequences, next_scores, next_is_sent_finished = gather_beams(
        [merged_sequences, merged_scores, merged_is_sent_finished],
        topk_merged_indices, batch_size, num_beams)

    # 8. Update model kwargs.
    # Determine the top k beam indices from the original set of all beams.
    # With these, gather the top k beam-associated caches.
    next_running_indices = gather_beams(topk_beam_indices, next_topk_indices,
                                        batch_size, num_beams)
    next_cache = gather_beams(cache, next_running_indices, batch_size,
                              num_beams)
    # Flatten the beam dimension back into batch before the cache is handed
    # to the model again.
    model_outputs["past_key_values"] = jax.tree_util.tree_map(
        flatten_beam_dim, next_cache)
    next_model_kwargs = self.update_inputs_for_generation(
        model_outputs, state.model_kwargs)

    return BeamSearchState(
        cur_len=state.cur_len + 1,
        running_scores=next_running_scores,
        running_sequences=next_running_sequences,
        scores=next_scores,
        sequences=next_sequences,
        is_sent_finished=next_is_sent_finished,
        model_kwargs=next_model_kwargs,
    )
def testTopKGrad(self, shape, dtype, k, rng_factory): flat_values = np.arange(prod(shape), dtype=dtype) values = self.rng().permutation(flat_values).reshape(shape) fun = lambda vs: lax.top_k(vs, k=k)[0] check_grads(fun, (values,), 2, ["fwd", "rev"], eps=1e-2)