def greedy_search(self, h: torch.Tensor, encoded_lengths: torch.Tensor) -> List[Hypothesis]:
    """Greedy search implementation for transducer.
    Generic case when beam size = 1. Results may differ slightly from
    `GreedyRNNTInfer` and `GreedyBatchRNNTInfer` due to implementation details.

    Args:
        h: Encoded speech features (1, T_max, D_enc)
        encoded_lengths: Lengths of the encoded sequences (1,)

    Returns:
        hyp: 1-best decoding result (as a single-element list)
    """
    # Initialize zero state vectors
    dec_state = self.decoder.initialize_state(h)

    # Construct initial hypothesis
    hyp = Hypothesis(
        score=0.0, y_sequence=[self.blank], dec_state=dec_state, timestep=[-1], length=encoded_lengths
    )
    cache = {}

    # Initialize state and first token
    y, state, _ = self.decoder.score_hypothesis(hyp, cache)

    for i in range(int(encoded_lengths)):
        hi = h[:, i : i + 1, :]  # [1, 1, D]

        not_blank = True
        symbols_added = 0

        while not_blank:
            ytu = torch.log_softmax(self.joint.joint(hi, y), dim=-1)  # [1, 1, 1, V + 1]
            ytu = ytu[0, 0, 0, :]  # [V + 1]

            # max() requires float
            if ytu.dtype != torch.float32:
                ytu = ytu.float()

            logp, pred = torch.max(ytu, dim=-1)  # [1, 1]
            pred = pred.item()

            if pred == self.blank:
                not_blank = False
            else:
                # Update state and current sequence
                hyp.y_sequence.append(int(pred))
                hyp.score += float(logp)
                hyp.dec_state = state
                hyp.timestep.append(i)

                # Compute next state and token
                y, state, _ = self.decoder.score_hypothesis(hyp, cache)
            symbols_added += 1

    return [hyp]
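# Illustrative sketch (hypothetical values, independent of the classes above): the control
# flow of the greedy transducer loop. The frame index i advances only after a blank is
# predicted; non-blank predictions stay on the same frame and extend the hypothesis.
# For a hypothetical 3-frame utterance predicting (5, 7, blank) at frame 0, (blank,) at
# frame 1, and (2, blank) at frame 2, the hypothesis fields end up as:
y_sequence = [28, 5, 7, 2]  # leading entry is the initial blank (blank id = 28 here)
timestep = [-1, 0, 0, 2]    # frame index at which each non-blank token was emitted
# hyp.score accumulates only the non-blank log-probabilities in this implementation.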
def ctc_decoder_predictions_tensor(
    self, predictions: torch.Tensor, predictions_len: torch.Tensor = None, return_hypotheses: bool = False,
) -> List[str]:
    """
    Decodes a sequence of labels to words.

    Args:
        predictions: A torch.Tensor of shape [Batch, Time] of integer indices that correspond
            to the index of some character in the label set.
        predictions_len: Optional tensor of length `Batch` which contains the integer lengths
            of the sequences in the padded `predictions` tensor.
        return_hypotheses: Bool flag whether to return just the decoded strings of the model
            or a Hypothesis object that holds information such as the decoded `text`, the
            `alignment` emitted by the CTC model, and the `length` of the sequence (if available).
            May also contain the log-probabilities of the decoder (if this method is called via
            transcribe()).

    Returns:
        Either a list of str which represent the CTC decoded strings per sample,
        or a list of Hypothesis objects containing additional information.
    """
    hypotheses = []

    # Move predictions to CPU
    prediction_cpu_tensor = predictions.long().cpu()

    # iterate over batch
    for ind in range(prediction_cpu_tensor.shape[self.batch_dim_index]):
        prediction = prediction_cpu_tensor[ind].detach().numpy().tolist()

        if predictions_len is not None:
            prediction = prediction[: predictions_len[ind]]

        # CTC decoding procedure: merge repeated tokens, then drop blanks
        decoded_prediction = []
        previous = self.blank_id

        for p in prediction:
            if (p != previous or previous == self.blank_id) and p != self.blank_id:
                decoded_prediction.append(p)
            previous = p

        text = self.decode_tokens_to_str(decoded_prediction)

        if not return_hypotheses:
            hypothesis = text
        else:
            hypothesis = Hypothesis(
                y_sequence=None,
                score=-1.0,
                text=text,
                alignments=prediction,
                length=predictions_len[ind] if predictions_len is not None else 0,
            )

        hypotheses.append(hypothesis)

    return hypotheses
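# Illustrative sketch (not part of the source above): the CTC collapse rule used in
# ctc_decoder_predictions_tensor, applied to a hypothetical toy sequence where indices
# 1-3 are characters and 4 stands in for blank_id. Repeats are merged unless separated
# by a blank, and blanks themselves are dropped.
blank_id = 4
prediction = [1, 1, 4, 1, 2, 2, 4, 4, 3]

decoded_prediction = []
previous = blank_id
for p in prediction:
    # emit a token only if it differs from the previous symbol (or follows a blank)
    # and is not the blank symbol itself
    if (p != previous or previous == blank_id) and p != blank_id:
        decoded_prediction.append(p)
    previous = p

print(decoded_prediction)  # [1, 1, 2, 3] -- the two 1s separated by a blank are both kept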
def test_RNNTDecoder(self):
    vocab = list(range(10))
    vocab = [str(x) for x in vocab]
    vocab_size = len(vocab)

    pred_config = OmegaConf.create(
        {
            '_target_': 'nemo.collections.asr.modules.RNNTDecoder',
            'prednet': {'pred_hidden': 32, 'pred_rnn_layers': 1},
            'vocab_size': vocab_size,
            'blank_as_pad': True,
        }
    )

    prednet = modules.RNNTDecoder.from_config_dict(pred_config)

    # num params
    pred_hidden = pred_config.prednet.pred_hidden
    embed = (vocab_size + 1) * pred_hidden  # embedding with blank
    rnn = 2 * 4 * (pred_hidden * pred_hidden + pred_hidden)  # (ih + hh) * (i, f, c, o gates) * (in dim * hidden dim + bias)
    assert prednet.num_weights == (embed + rnn)

    # State initialization
    x_ = torch.zeros(4, dtype=torch.float32)
    states = prednet.initialize_state(x_)

    for state_i in states:
        assert state_i.dtype == x_.dtype
        assert state_i.device == x_.device
        assert state_i.shape[1] == len(x_)

    # Blank hypotheses test
    blank = vocab_size
    hyp = Hypothesis(score=0.0, y_sequence=[blank])
    cache = {}
    pred, states, _ = prednet.score_hypothesis(hyp, cache)

    assert pred.shape == torch.Size([1, 1, pred_hidden])
    assert len(states) == 2
    for state_i in states:
        assert state_i.dtype == pred.dtype
        assert state_i.device == pred.device
        assert state_i.shape[1] == len(pred)

    # Blank stateless predict
    g, states = prednet.predict(y=None, state=None, add_sos=False, batch_size=1)

    assert g.shape == torch.Size([1, 1, pred_hidden])
    assert len(states) == 2
    for state_i in states:
        assert state_i.dtype == g.dtype
        assert state_i.device == g.device
        assert state_i.shape[1] == len(g)

    # Blank stateful predict
    g, states2 = prednet.predict(y=None, state=states, add_sos=False, batch_size=1)

    assert g.shape == torch.Size([1, 1, pred_hidden])
    assert len(states2) == 2
    for state_i, state_j in zip(states, states2):
        assert (state_i - state_j).square().sum().sqrt() > 0.0

    # Predict with token and state
    token = torch.full([1, 1], fill_value=0, dtype=torch.long)
    g, states = prednet.predict(y=token, state=states2, add_sos=False, batch_size=None)

    assert g.shape == torch.Size([1, 1, pred_hidden])
    assert len(states) == 2

    # Predict with blank token and no state
    token = torch.full([1, 1], fill_value=blank, dtype=torch.long)
    g, states = prednet.predict(y=token, state=None, add_sos=False, batch_size=None)

    assert g.shape == torch.Size([1, 1, pred_hidden])
    assert len(states) == 2
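# Worked arithmetic behind the num_weights assertion in test_RNNTDecoder above, using the
# test's configuration (vocab_size=10, pred_hidden=32, one LSTM layer). This is only the
# expected-count calculation from the test, not an additional API.
vocab_size = 10
pred_hidden = 32

embed = (vocab_size + 1) * pred_hidden  # 11 * 32 = 352 (embedding rows include blank)
rnn = 2 * 4 * (pred_hidden * pred_hidden + pred_hidden)  # (ih + hh) * 4 gates * (32*32 + 32) = 8448

print(embed + rnn)  # 8800 == prednet.num_weights for this configuration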
def align_length_sync_decoding(self, h: torch.Tensor, encoded_lengths: torch.Tensor) -> List[Hypothesis]:
    """Alignment-length synchronous beam search implementation.
    Based on https://ieeexplore.ieee.org/document/9053040

    Args:
        h: Encoded speech features (1, T_max, D_enc)
        encoded_lengths: Lengths of the encoded sequences (1,)

    Returns:
        nbest_hyps: N-best decoding results
    """
    # Precompute some constants for blank position
    ids = list(range(self.vocab_size + 1))
    ids.remove(self.blank)

    # Used when blank token is first vs last token
    if self.blank == 0:
        index_incr = 1
    else:
        index_incr = 0

    # prepare the batched beam states
    beam = min(self.beam_size, self.vocab_size)

    h = h[0]  # [T, D]
    h_length = int(encoded_lengths)
    beam_state = self.decoder.initialize_state(
        torch.zeros(beam, device=h.device, dtype=h.dtype)
    )  # [L, B, H], [L, B, H] for LSTMs

    # compute u_max as either a specific static limit,
    # or a multiple of the current `h_length` dynamically.
    if type(self.alsd_max_target_length) == float:
        u_max = int(self.alsd_max_target_length * h_length)
    else:
        u_max = int(self.alsd_max_target_length)

    # Initialize first hypothesis for the beam (blank)
    B = [
        Hypothesis(
            y_sequence=[self.blank],
            score=0.0,
            dec_state=self.decoder.batch_select_state(beam_state, 0),
        )
    ]

    final = []
    cache = {}

    # ALSD runs for T + U_max steps
    for i in range(h_length + u_max):
        # Update caches
        A = []
        B_ = []
        h_states = []

        # preserve the list of batch indices which are added into the list
        # and those which are removed from the list
        # This is necessary to perform state updates in the correct batch indices later
        batch_ids = list(range(len(B)))  # initialize as a list of all batch ids
        batch_removal_ids = []  # update with sample ids which are removed

        for bid, hyp in enumerate(B):
            u = len(hyp.y_sequence) - 1
            t = i - u + 1

            if t > (h_length - 1):
                batch_removal_ids.append(bid)
                continue

            B_.append(hyp)
            h_states.append((t, h[t]))

        if B_:
            # Compute the subset of batch ids which were *not* removed from the list above
            sub_batch_ids = None
            if len(B_) != beam:
                sub_batch_ids = batch_ids
                for id in batch_removal_ids:
                    # sub_batch_ids contains the list of ids *that were not removed*
                    sub_batch_ids.remove(id)

                # extract the states of the sub-batch only.
                beam_state_ = [
                    beam_state[state_id][:, sub_batch_ids, :] for state_id in range(len(beam_state))
                ]
            else:
                # If the entire batch was used (none were removed), simply take all the states
                beam_state_ = beam_state

            # Decode a batch/sub-batch of beam states and scores
            beam_y, beam_state_, beam_lm_tokens = self.decoder.batch_score_hypothesis(B_, cache, beam_state_)

            # If only a subset of batch ids were updated (some were removed)
            if sub_batch_ids is not None:
                # For each state in the RNN (2 for LSTM)
                for state_id in range(len(beam_state)):
                    # Update the current batch states with the sub-batch states (in the correct indices)
                    # These indices are specified by sub_batch_ids, the ids of samples which were updated.
                    beam_state[state_id][:, sub_batch_ids, :] = beam_state_[state_id][...]
            else:
                # If the entire batch was updated, simply update all the states
                beam_state = beam_state_

            # h_states = list of [t, h[t]]
            # so h[1] here is a h[t] of shape [D]
            # Simply stack all of the h[t] within the sub_batch/batch (T <= beam)
            h_enc = torch.stack([h[1] for h in h_states])  # [T=beam, D]
            h_enc = h_enc.unsqueeze(1)  # [B=beam, T=1, D]; batch over the beams

            # Extract the log probabilities and the predicted tokens
            beam_logp = torch.log_softmax(self.joint.joint(h_enc, beam_y), dim=-1)  # [B=beam, 1, 1, V + 1]
            beam_logp = beam_logp[:, 0, 0, :]  # [B=beam, V + 1]
            beam_topk = beam_logp[:, ids].topk(beam, dim=-1)

            for j, hyp in enumerate(B_):
                # For all updated samples in the batch, add the blank token expansion
                # In this step, we don't add a token but simply update the score
                new_hyp = Hypothesis(
                    score=(hyp.score + float(beam_logp[j, self.blank])),
                    y_sequence=hyp.y_sequence[:],
                    dec_state=hyp.dec_state,
                    lm_state=hyp.lm_state,
                )

                # Add blank prediction to A
                A.append(new_hyp)

                # If the prediction "timestep" t has reached the length of the input sequence
                # we can add it to the "finished" hypothesis list.
                if h_states[j][0] == (h_length - 1):
                    final.append(new_hyp)

                # Here, we carefully select the indices of the states that we want to preserve
                # for the next token (non-blank) update.
                if sub_batch_ids is not None:
                    h_states_idx = sub_batch_ids[j]
                else:
                    h_states_idx = j

                # for each current hypothesis j,
                # extract the top token score and top token id for the jth hypothesis
                for logp, k in zip(beam_topk[0][j], beam_topk[1][j] + index_incr):
                    # create a new hypothesis and store it in A
                    # Note: This loop does *not* include the blank token!
                    new_hyp = Hypothesis(
                        score=(hyp.score + float(logp)),
                        y_sequence=(hyp.y_sequence[:] + [int(k)]),
                        dec_state=self.decoder.batch_select_state(beam_state, h_states_idx),
                        lm_state=hyp.lm_state,
                    )

                    A.append(new_hyp)

            # Prune and recombine identical hypotheses
            # This may cause the next beam to be smaller than the max beam size,
            # therefore larger beam sizes may be required for better decoding.
            B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
            B = self.recombine_hypotheses(B)

        # If B_ is an empty list, then we may be able to exit early
        elif len(batch_ids) == len(batch_removal_ids):
            break

    if final:
        return self.sort_nbest(final)
    else:
        return B
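# Illustrative sketch (hypothetical values, not defaults from the source): how u_max in
# align_length_sync_decoding depends on the type of alsd_max_target_length.
h_length = 100  # number of encoder frames

alsd_max_target_length = 2.0  # float -> dynamic cap, a multiple of h_length
u_max = int(alsd_max_target_length * h_length)  # 200, so the ALSD loop runs 100 + 200 = 300 steps

alsd_max_target_length = 30  # int -> static cap, independent of h_length
u_max = int(alsd_max_target_length)  # 30, so the loop runs 100 + 30 = 130 steps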
def time_sync_decoding(self, h: torch.Tensor, encoded_lengths: torch.Tensor) -> List[Hypothesis]:
    """Time synchronous beam search implementation.
    Based on https://ieeexplore.ieee.org/document/9053040

    Args:
        h: Encoded speech features (1, T_max, D_enc)
        encoded_lengths: Lengths of the encoded sequences (1,)

    Returns:
        nbest_hyps: N-best decoding results
    """
    # Precompute some constants for blank position
    ids = list(range(self.vocab_size + 1))
    ids.remove(self.blank)

    # Used when blank token is first vs last token
    if self.blank == 0:
        index_incr = 1
    else:
        index_incr = 0

    # prepare the batched beam states
    beam = min(self.beam_size, self.vocab_size)
    beam_state = self.decoder.initialize_state(
        torch.zeros(beam, device=h.device, dtype=h.dtype)
    )  # [L, B, H], [L, B, H] (for LSTMs)

    # Initialize first hypothesis for the beam (blank)
    B = [
        Hypothesis(
            y_sequence=[self.blank],
            score=0.0,
            dec_state=self.decoder.batch_select_state(beam_state, 0),
        )
    ]
    cache = {}

    for i in range(int(encoded_lengths)):
        hi = h[:, i : i + 1, :]

        # Update caches
        A = []
        C = B

        h_enc = hi

        # For a limited number of symmetric expansions per timestep "i"
        for v in range(self.tsd_max_symmetric_expansion_per_step):
            D = []

            # Decode a batch of beam states and scores
            beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score_hypothesis(C, cache, beam_state)

            # Extract the log probabilities and the predicted tokens
            beam_logp = torch.log_softmax(self.joint.joint(h_enc, beam_y), dim=-1)  # [B, 1, 1, V + 1]
            beam_logp = beam_logp[:, 0, 0, :]  # [B, V + 1]
            beam_topk = beam_logp[:, ids].topk(beam, dim=-1)

            seq_A = [h.y_sequence for h in A]

            for j, hyp in enumerate(C):
                # create a new hypothesis in A
                if hyp.y_sequence not in seq_A:
                    # If the sequence is not in seq_A, add it as the blank token expansion
                    # In this step, we don't add a token but simply update the score
                    A.append(
                        Hypothesis(
                            score=(hyp.score + float(beam_logp[j, self.blank])),
                            y_sequence=hyp.y_sequence[:],
                            dec_state=hyp.dec_state,
                            lm_state=hyp.lm_state,
                        )
                    )
                else:
                    # merge the existing blank hypothesis score with the current score.
                    dict_pos = seq_A.index(hyp.y_sequence)
                    A[dict_pos].score = np.logaddexp(
                        A[dict_pos].score, (hyp.score + float(beam_logp[j, self.blank]))
                    )

            if v < self.tsd_max_symmetric_expansion_per_step:
                for j, hyp in enumerate(C):
                    # for each current hypothesis j,
                    # extract the top token score and top token id for the jth hypothesis
                    for logp, k in zip(beam_topk[0][j], beam_topk[1][j] + index_incr):
                        # create a new hypothesis and store it in D
                        # Note: This loop does *not* include the blank token!
                        new_hyp = Hypothesis(
                            score=(hyp.score + float(logp)),
                            y_sequence=(hyp.y_sequence + [int(k)]),
                            dec_state=self.decoder.batch_select_state(beam_state, j),
                            lm_state=hyp.lm_state,
                        )

                        D.append(new_hyp)

            # Prune beam
            C = sorted(D, key=lambda x: x.score, reverse=True)[:beam]

        # Prune beam
        B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]

    return self.sort_nbest(B)
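# Illustrative sketch (hypothetical scores, not taken from the source): the recombination
# step in time_sync_decoding. When two expansions share the same y_sequence, their
# log-probability scores are merged with logaddexp, i.e. their probabilities are summed.
import numpy as np

score_existing = -4.2
score_new = -5.0

merged = np.logaddexp(score_existing, score_new)  # log(exp(-4.2) + exp(-5.0))
print(merged)  # ~ -3.83, higher (more probable) than either hypothesis alone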
def default_beam_search(self, h: torch.Tensor, encoded_lengths: torch.Tensor) -> List[Hypothesis]:
    """Beam search implementation.

    Args:
        h: Encoded speech features (1, T_max, D_enc)
        encoded_lengths: Lengths of the encoded sequences (1,)

    Returns:
        nbest_hyps: N-best decoding results
    """
    # Initialize states
    beam = min(self.beam_size, self.vocab_size)
    beam_k = min(beam, (self.vocab_size - 1))
    blank_tensor = torch.tensor([self.blank], device=h.device, dtype=torch.long)

    # Precompute some constants for blank position
    ids = list(range(self.vocab_size + 1))
    ids.remove(self.blank)

    # Used when blank token is first vs last token
    if self.blank == 0:
        index_incr = 1
    else:
        index_incr = 0

    # Initialize zero vector states
    dec_state = self.decoder.initialize_state(h)

    # Initialize first hypothesis for the beam (blank)
    kept_hyps = [Hypothesis(score=0.0, y_sequence=[self.blank], dec_state=dec_state)]
    cache = {}

    for i in range(int(encoded_lengths)):
        hi = h[:, i : i + 1, :]  # [1, 1, D]
        hyps = kept_hyps
        kept_hyps = []

        while True:
            max_hyp = max(hyps, key=lambda x: x.score)
            hyps.remove(max_hyp)

            # update decoder state and get next score
            y, state, lm_tokens = self.decoder.score_hypothesis(max_hyp, cache)  # [1, 1, D]

            # get next token
            ytu = torch.log_softmax(self.joint.joint(hi, y), dim=-1)  # [1, 1, 1, V + 1]
            ytu = ytu[0, 0, 0, :]  # [V + 1]

            # remove the blank token before top k
            top_k = ytu[ids].topk(beam_k, dim=-1)

            # Two possible steps - blank token or non-blank token predicted
            ytu = (
                torch.cat((top_k[0], ytu[self.blank].unsqueeze(0))),
                torch.cat((top_k[1] + index_incr, blank_tensor)),
            )

            # for each possible step
            for logp, k in zip(*ytu):
                # construct the hypothesis for this step
                new_hyp = Hypothesis(
                    score=(max_hyp.score + float(logp)),
                    y_sequence=max_hyp.y_sequence[:],
                    dec_state=max_hyp.dec_state,
                    lm_state=max_hyp.lm_state,
                )

                # if the current token is blank, don't update the sequence, just store the current hypothesis
                if k == self.blank:
                    kept_hyps.append(new_hyp)
                else:
                    # if a non-blank token was predicted, update state and sequence, then search for more hypotheses
                    new_hyp.dec_state = state
                    new_hyp.y_sequence.append(int(k))
                    hyps.append(new_hyp)

            # keep those hypotheses that have scores greater than the next search generation
            hyps_max = float(max(hyps, key=lambda x: x.score).score)
            kept_most_prob = sorted(
                [hyp for hyp in kept_hyps if hyp.score > hyps_max], key=lambda x: x.score,
            )

            # If enough hypotheses have scores greater than the next search generation,
            # stop the beam search.
            if len(kept_most_prob) >= beam:
                kept_hyps = kept_most_prob
                break

    return self.sort_nbest(kept_hyps)
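# Illustrative sketch (toy 4-token vocabulary plus blank, hypothetical logits): why the
# search above adds index_incr to the top-k indices when blank is token 0. Removing the
# blank column before topk shifts every remaining position down by one, so the recovered
# indices must be shifted back to true vocabulary ids.
import torch

ytu = torch.tensor([0.5, 0.1, 0.9, 0.2, 0.3]).log_softmax(dim=-1)  # [V + 1], blank at index 0
blank = 0
ids = [1, 2, 3, 4]  # vocabulary ids with blank removed
index_incr = 1      # blank sits at position 0, so gathered positions are off by one

top_k = ytu[ids].topk(2, dim=-1)
tokens = top_k.indices + index_incr  # map positions within ytu[ids] back to vocabulary ids
print(tokens)  # tensor([2, 4]) -- the true token ids of the two best non-blank scores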
def greedy_search(self, h: torch.Tensor, encoded_lengths: torch.Tensor) -> List[Hypothesis]:
    """Greedy search implementation for transducer.
    Generic case when beam size = 1. Results may differ slightly from
    `GreedyRNNTInfer` and `GreedyBatchRNNTInfer` due to implementation details.

    Args:
        h: Encoded speech features (1, T_max, D_enc)
        encoded_lengths: Lengths of the encoded sequences (1,)

    Returns:
        hyp: 1-best decoding result (as a single-element list)
    """
    if self.preserve_alignments:
        # Alignments is a 2-dimensional dangling list representing T x U
        alignments = [[]]
    else:
        alignments = None

    # Initialize zero state vectors
    dec_state = self.decoder.initialize_state(h)

    # Construct initial hypothesis
    hyp = Hypothesis(
        score=0.0, y_sequence=[self.blank], dec_state=dec_state, timestep=[-1], length=encoded_lengths
    )
    cache = {}

    # Initialize state and first token
    y, state, _ = self.decoder.score_hypothesis(hyp, cache)

    for i in range(int(encoded_lengths)):
        hi = h[:, i : i + 1, :]  # [1, 1, D]

        not_blank = True
        symbols_added = 0

        while not_blank:
            ytu = torch.log_softmax(self.joint.joint(hi, y), dim=-1)  # [1, 1, 1, V + 1]
            ytu = ytu[0, 0, 0, :]  # [V + 1]

            # max() requires float
            if ytu.dtype != torch.float32:
                ytu = ytu.float()

            logp, pred = torch.max(ytu, dim=-1)  # [1, 1]
            pred = pred.item()

            if self.preserve_alignments:
                # insert the predicted token id into the current timestep's alignment
                alignments[-1].append(pred)

            if pred == self.blank:
                not_blank = False

                if self.preserve_alignments:
                    # this timestep is finished; open a blank buffer for the next timestep
                    alignments.append([])
            else:
                # Update state and current sequence
                hyp.y_sequence.append(int(pred))
                hyp.score += float(logp)
                hyp.dec_state = state
                hyp.timestep.append(i)

                # Compute next state and token
                y, state, _ = self.decoder.score_hypothesis(hyp, cache)
            symbols_added += 1

    # Remove the trailing empty list of alignments
    if self.preserve_alignments:
        if len(alignments[-1]) == 0:
            del alignments[-1]

    # attach alignments to the hypothesis
    hyp.alignments = alignments

    return [hyp]
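# Illustrative sketch (hypothetical token ids): the shape of hyp.alignments produced by the
# greedy_search above when preserve_alignments is enabled. It is a list of per-timestep
# lists; each inner list holds the predicted token ids for that frame, ending with the blank
# prediction that advanced decoding to the next frame (here blank id = 28).
alignments = [
    [5, 7, 28],  # frame 0: two non-blank symbols, then blank
    [28],        # frame 1: blank immediately, no symbol emitted
    [2, 28],     # frame 2: one symbol, then blank
]
# The trailing empty list opened after the last frame's blank is removed before the
# alignments are attached to the hypothesis.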