def __call__(self, audio_signal: torch.Tensor, length: torch.Tensor):
    """Returns a list of hypotheses given an input batch of the encoder hidden embedding.
    Output tokens are generated auto-regressively.

    Args:
        audio_signal: A tensor of size (batch, features, timesteps).
        length: list of int representing the length of each output sequence.

    Returns:
        packed list containing batch number of sentences (Hypotheses).
    """
    with torch.no_grad():
        # Apply optional preprocessing
        encoder_output, encoded_lengths = self.run_encoder(audio_signal=audio_signal, length=length)

        # torch.Tensor.transpose swaps exactly two dims; a permutation list is not accepted
        encoder_output = encoder_output.transpose(1, 2)  # (B, T, D)
        logitlen = encoded_lengths

        inseq = encoder_output  # [B, T, D]
        hypotheses, timestamps = self._greedy_decode(inseq, logitlen)

        # Pack the hypotheses results
        packed_result = [rnnt_utils.Hypothesis(score=-1.0, y_sequence=[]) for _ in range(len(hypotheses))]
        for i in range(len(packed_result)):
            packed_result[i].y_sequence = torch.tensor(hypotheses[i], dtype=torch.long)
            packed_result[i].length = timestamps[i]

        del hypotheses

    return packed_result
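# Minimal sketch (illustration only, not part of the decoder): torch.Tensor.transpose
# swaps exactly two dimensions, which is why the (B, D, T) -> (B, T, D) conversion above
# uses .transpose(1, 2); a NumPy-style permutation list would need .permute(0, 2, 1).
# The dummy tensor shape below is fabricated for the example.
def _example_transpose_sketch():
    import torch

    enc = torch.randn(2, 512, 100)  # (B, D, T) as emitted by a typical encoder
    assert enc.transpose(1, 2).shape == (2, 100, 512)  # (B, T, D)
    assert enc.permute(0, 2, 1).shape == (2, 100, 512)  # equivalent permutation form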
def pack_hypotheses(
    hypotheses: List[List[int]],
    timesteps: List[List[int]],
    logitlen: torch.Tensor,
    alignments: Optional[List[List[int]]] = None,
) -> List[rnnt_utils.Hypothesis]:
    logitlen_cpu = logitlen.to('cpu')
    return [
        rnnt_utils.Hypothesis(
            y_sequence=torch.tensor(sent, dtype=torch.long),
            score=-1.0,
            timestep=timestep,
            length=length,
            alignments=alignments[idx] if alignments is not None else None,
        )
        for idx, (sent, timestep, length) in enumerate(zip(hypotheses, timesteps, logitlen_cpu))
    ]
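# Minimal usage sketch for pack_hypotheses (illustration only). The token ids, emission
# timesteps, and lengths below are fabricated; in practice they come from a greedy
# decoding loop such as the ones further down in this file.
def _example_pack_hypotheses_sketch():
    import torch

    hyps = [[3, 7, 7, 1], [5, 2]]    # per-sample greedy token ids (blanks removed)
    steps = [[0, 2, 3, 5], [1, 4]]   # encoder timestep at which each token was emitted
    lens = torch.tensor([6, 5])      # per-sample encoder output lengths

    packed = pack_hypotheses(hyps, steps, lens)
    assert packed[0].y_sequence.tolist() == [3, 7, 7, 1]
    assert int(packed[1].length) == 5
    assert packed[0].alignments is None  # no alignments were passed in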
def _greedy_decode_masked(
    self,
    x: torch.Tensor,
    out_len: torch.Tensor,
    device: torch.device,
    partial_hypotheses: Optional[List[rnnt_utils.Hypothesis]] = None,
):
    if partial_hypotheses is not None:
        raise NotImplementedError("`partial_hypotheses` support is not implemented")

    # x: [B, T, D]
    # out_len: [B]
    # device: torch.device

    # Initialize state
    batchsize = x.shape[0]
    hypotheses = [
        rnnt_utils.Hypothesis(score=0.0, y_sequence=[], timestep=[], dec_state=None) for _ in range(batchsize)
    ]

    # Initialize Hidden state matrix (shared by entire batch)
    hidden = None

    # If alignments need to be preserved, register a dangling list to hold the values
    if self.preserve_alignments:
        # alignments is a 3-dimensional dangling list representing B x T x U
        for hyp in hypotheses:
            hyp.alignments = [[]]

    # Last Label buffer + Last Label without blank buffer
    # batch level equivalent of the last_label
    last_label = torch.full([batchsize, 1], fill_value=self._blank_index, dtype=torch.long, device=device)
    last_label_without_blank = last_label.clone()

    # Mask buffers
    blank_mask = torch.full([batchsize], fill_value=0, dtype=torch.bool, device=device)

    # Get max sequence length
    max_out_len = out_len.max()

    for time_idx in range(max_out_len):
        f = x.narrow(dim=1, start=time_idx, length=1)  # [B, 1, D]

        # Prepare this timestep's batch variables
        not_blank = True
        symbols_added = 0

        # Reset the blank mask from the time mask
        # Batch: [B, T, D], but sample Bi may have seq len < max(seq_lens_in_batch)
        # Forcibly mask with "blank" tokens for all samples where the current time step T >= seq_len
        blank_mask = time_idx >= out_len

        # Start inner loop
        while not_blank and (self.max_symbols is None or symbols_added < self.max_symbols):
            # Batch prediction and joint network steps
            # If very first prediction step, submit SOS tag (blank) to pred_step.
            # This feeds a zero tensor as input to AbstractRNNTDecoder to prime the state
            if time_idx == 0 and symbols_added == 0 and hidden is None:
                g, hidden_prime = self._pred_step(self._SOS, hidden, batch_size=batchsize)
            else:
                # Set a dummy label for the blank value
                # This value will be overwritten by "blank" again in the last-label update below
                # This is done because the vocabulary of the prediction network does not contain
                # the RNNT "blank" token
                last_label_without_blank_mask = last_label == self._blank_index
                last_label_without_blank[last_label_without_blank_mask] = 0  # temp change of label
                last_label_without_blank[~last_label_without_blank_mask] = last_label[
                    ~last_label_without_blank_mask
                ]

                # Perform batch step prediction of decoder, getting new states and scores ("g")
                g, hidden_prime = self._pred_step(last_label_without_blank, hidden, batch_size=batchsize)

            # Batched joint step - Output = [B, V + 1]
            logp = self._joint_step(f, g, log_normalize=None)[:, 0, 0, :]

            if logp.dtype != torch.float32:
                logp = logp.float()

            # Get index k, of max prob for batch
            v, k = logp.max(1)
            del g

            # Update blank mask with current predicted blanks
            # This is accumulating blanks over all time steps T and all target steps min(max_symbols, U)
            k_is_blank = k == self._blank_index
            blank_mask.bitwise_or_(k_is_blank)

            # If preserving alignments, check if sequence length of sample has been reached
            # before adding alignment
            if self.preserve_alignments:
                # Insert ids into last timestep per sample
                logp_vals = logp.to('cpu').max(1)[1]
                for batch_idx in range(batchsize):
                    if time_idx < out_len[batch_idx]:
                        hypotheses[batch_idx].alignments[-1].append(logp_vals[batch_idx])
                del logp_vals
            del logp

            # If all samples predict / have predicted prior blanks, exit loop early
            # This is the batched equivalent of a single sample predicting blank
            if blank_mask.all():
                not_blank = False

                # If preserving alignments, convert the current Uj alignments into a torch.Tensor
                # Then preserve U at current timestep Ti
                # Finally, forward the timestep history to Ti+1 for that sample
                # All of this should only be done iff the current time index <= sample-level AM length.
                # Otherwise ignore and move to next sample / next timestep.
                if self.preserve_alignments:
                    # convert Ti-th logits into a torch array
                    for batch_idx in range(batchsize):
                        # this checks if current timestep <= sample-level AM length
                        # If current timestep > sample-level AM length, no alignments will be added
                        # Therefore the list of Uj alignments is empty here.
                        if len(hypotheses[batch_idx].alignments[-1]) > 0:
                            hypotheses[batch_idx].alignments.append([])  # blank buffer for next timestep
            else:
                # Collect batch indices where blanks occurred now/past
                blank_indices = (blank_mask == 1).nonzero(as_tuple=False)

                # Recover prior state for all samples which predicted blank now/past
                if hidden is not None:
                    # LSTM has 2 states
                    hidden_prime = self.decoder.batch_copy_states(hidden_prime, hidden, blank_indices)
                elif len(blank_indices) > 0 and hidden is None:
                    # Reset state if there were some blank and other non-blank predictions in batch
                    # Original state is filled with zeros so we just multiply
                    # LSTM has 2 states
                    hidden_prime = self.decoder.batch_copy_states(hidden_prime, None, blank_indices, value=0.0)

                # Recover prior predicted label for all samples which predicted blank now/past
                k[blank_indices] = last_label[blank_indices, 0]

                # Update new label and hidden state for next iteration
                last_label = k.view(-1, 1)
                hidden = hidden_prime

                # Update predicted labels, accounting for time mask
                # If blank was predicted even once, now or in the past,
                # force the current predicted label to also be blank.
                # This ensures that blanks propagate across all timesteps
                # once they have occurred (normally the stopping condition of the sample-level loop).
                for kidx, ki in enumerate(k):
                    if blank_mask[kidx] == 0:
                        hypotheses[kidx].y_sequence.append(ki)
                        hypotheses[kidx].timestep.append(time_idx)
                        hypotheses[kidx].score += float(v[kidx])

            symbols_added += 1

    # Remove trailing empty list of alignments at T_{am-len} x Uj
    if self.preserve_alignments:
        for batch_idx in range(batchsize):
            if len(hypotheses[batch_idx].alignments[-1]) == 0:
                del hypotheses[batch_idx].alignments[-1]

    # Preserve states
    for batch_idx in range(batchsize):
        hypotheses[batch_idx].dec_state = self.decoder.batch_select_state(hidden, batch_idx)

    return hypotheses
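# Standalone sketch of the blank-mask bookkeeping used in _greedy_decode_masked above
# (illustration only, fabricated labels). Once a sample emits blank at the current
# timestep, bitwise_or_ pins it to blank for the remainder of the inner loop, so
# finished samples idle while the rest of the batch keeps emitting symbols.
def _example_blank_mask_sketch():
    import torch

    blank_index = 4
    blank_mask = torch.zeros(3, dtype=torch.bool)

    # Fabricated per-step argmax labels for a batch of 3, over two inner-loop steps
    for k in (torch.tensor([1, 4, 2]), torch.tensor([4, 3, 4])):
        blank_mask.bitwise_or_(k == blank_index)

    # Sample 1 went blank at step 1 and stays masked even though step 2 predicted 3;
    # after step 2 every sample is blank, so the real inner loop would exit here.
    assert blank_mask.tolist() == [True, True, True]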
def _greedy_decode(
    self, x: torch.Tensor, out_len: torch.Tensor, partial_hypotheses: Optional[rnnt_utils.Hypothesis] = None
):
    # x: [T, 1, D]
    # out_len: [seq_len]

    # Initialize blank state and empty label set in Hypothesis
    hypothesis = rnnt_utils.Hypothesis(score=0.0, y_sequence=[], dec_state=None, timestep=[])

    if partial_hypotheses is not None:
        if len(partial_hypotheses.y_sequence) > 0:
            hypothesis.y_sequence.append(partial_hypotheses.y_sequence[-1].cpu().numpy())
            hypothesis.dec_state = self.decoder.batch_concat_states([partial_hypotheses.dec_state])
            hypothesis.dec_state = _states_to_device(hypothesis.dec_state, x.device)

    if self.preserve_alignments:
        # alignments is a 2-dimensional dangling list representing T x U
        hypothesis.alignments = [[]]

    # For timestep t in X_t
    for time_idx in range(out_len):
        # Extract encoder embedding at timestep t
        # f = x[time_idx, :, :].unsqueeze(0)  # [1, 1, D]
        f = x.narrow(dim=0, start=time_idx, length=1)

        # Setup exit flags and counter
        not_blank = True
        symbols_added = 0

        # While blank is not predicted and we haven't run out of max symbols per timestep
        while not_blank and (self.max_symbols is None or symbols_added < self.max_symbols):
            # In the first timestep, we initialize the network with RNNT Blank
            # In later timesteps, we provide the previous predicted label as input.
            last_label = (
                self._SOS
                if (hypothesis.y_sequence == [] and hypothesis.dec_state is None)
                else hypothesis.y_sequence[-1]
            )

            # Perform prediction network and joint network steps.
            g, hidden_prime = self._pred_step(last_label, hypothesis.dec_state)
            logp = self._joint_step(f, g, log_normalize=None)[0, 0, 0, :]
            del g

            # torch.max(0) op doesn't exist for FP16, so cast up to FP32.
            if logp.dtype != torch.float32:
                logp = logp.float()

            # get index k, of max prob
            v, k = logp.max(0)
            k = k.item()  # k is the label at timestep t_s in the inner loop, s >= 0.

            if self.preserve_alignments:
                # insert logits into last timestep
                hypothesis.alignments[-1].append(k)

            del logp

            # If blank token is predicted, exit inner loop, move onto next timestep t
            if k == self._blank_index:
                not_blank = False

                if self.preserve_alignments:
                    # convert Ti-th logits into a torch array
                    hypothesis.alignments.append([])  # blank buffer for next timestep
            else:
                # Append token to label set, update RNN state.
                hypothesis.y_sequence.append(k)
                hypothesis.score += float(v)
                hypothesis.timestep.append(time_idx)
                hypothesis.dec_state = hidden_prime

            # Increment token counter.
            symbols_added += 1

    # Remove trailing empty list of alignments
    if self.preserve_alignments:
        if len(hypothesis.alignments[-1]) == 0:
            del hypothesis.alignments[-1]

    # Unpack the hidden states
    hypothesis.dec_state = self.decoder.batch_select_state(hypothesis.dec_state, 0)

    # Remove the original input label if partial hypothesis was provided
    if partial_hypotheses is not None:
        hypothesis.y_sequence = hypothesis.y_sequence[1:]

    return hypothesis
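# Toy sketch of the control flow in _greedy_decode above (illustration only): a scripted
# sequence of argmax labels stands in for the prediction/joint networks. Blank advances
# the outer time loop, any other label is emitted, and max_symbols caps emissions per
# timestep. All labels below are fabricated.
def _example_greedy_loop_sketch():
    blank, max_symbols = 0, 3
    # scripted[t] lists the labels the joint would argmax to at encoder timestep t
    scripted = {0: [7, blank], 1: [blank], 2: [2, 5, 9, 9]}  # t=2 hits the symbol cap

    emitted = []
    for t in range(3):
        symbols_added = 0
        for k in scripted[t]:
            if k == blank or symbols_added >= max_symbols:
                break  # blank (or the symbol cap) moves decoding to the next timestep
            emitted.append((t, k))
            symbols_added += 1

    assert emitted == [(0, 7), (2, 2), (2, 5), (2, 9)]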