def forward(self, targets, target_length, states=None):
    """Run the prediction network over complete target sequences (training path).

    The state returned by ``self.predict`` is intentionally discarded: the
    training-time forward call does not maintain decoder state. Call
    ``.predict()`` directly when the state is needed.

    Args:
        targets: Target label sequences, collated via ``rnn.label_collate``
            (noted as (B, U) by the original author -- TODO confirm).
        target_length: Lengths of the targets; returned unchanged.
        states: Optional initial decoder state forwarded to ``self.predict``.

    Returns:
        A pair ``(pred_out_transposed, target_length)`` where the first item
        is the prediction output with its dims 1 and 2 swapped
        ((B, U, D) -> (B, D, U) per the original shape notes).
    """
    labels = rnn.label_collate(targets)

    # The SOS position is prepended by predict() itself (add_sos=True).
    pred_out, _discarded_state = self.predict(labels, state=states, add_sos=True)

    return pred_out.transpose(1, 2), target_length
def _pred_step(
    self,
    label: Union[torch.Tensor, int],
    hidden: Optional[torch.Tensor],
    add_sos: bool = False,
    batch_size: Optional[int] = None,
) -> (torch.Tensor, torch.Tensor):
    """Run one prediction-network step through ``self.decoder.predict``.

    Follows the AbstractRNNTDecoder step convention.

    Args:
        label: Either a tensor of token ids (documented upstream as
            [batch, 1]) or a single integer token id. An integer equal to
            ``self._SOS`` is the start-of-signal token and is passed to the
            decoder as ``None``.
        hidden: Optional decoder state, passed through unchanged.
        add_sos: Whether the decoder should prepend a zero vector at the
            beginning as a "start of sentence" token.
        batch_size: Batch size of the output tensor, passed through unchanged.

    Returns:
        Whatever ``self.decoder.predict`` returns. Per the decoder contract
        this is presumably ``(g, hid)``:
            g: (B, U, H) if add_sos is false, else (B, U + 1, H)
            hid: (h, c), each of shape (L, B, H)
    """
    if not isinstance(label, torch.Tensor):
        # Integer token id. SOS means "no previous label" for the decoder.
        if label == self._SOS:
            return self.decoder.predict(None, hidden, add_sos=add_sos, batch_size=batch_size)
        # Other ints are wrapped by label_collate (assumed to yield a [1, 1] tensor -- TODO confirm).
        label = label_collate([[label]])
    elif label.dtype != torch.long:
        # The decoder consumes int64 token ids; cast only when needed.
        label = label.long()

    # Output is [B, 1, K] per the original author's note.
    return self.decoder.predict(label, hidden, add_sos=add_sos, batch_size=batch_size)
def _greedy_decode(
    self,
    x: torch.Tensor,
    out_len: torch.Tensor,
    partial_hypotheses: Optional[rnnt_utils.Hypothesis] = None,
):
    """Greedily decode a single utterance, frame by frame.

    For every encoder frame, the prediction and joint networks are run
    repeatedly, emitting the arg-max token each time. The frame ends when
    the blank token is predicted or when ``self.max_symbols`` predictions
    have been made for that frame (there is no cap when
    ``self.max_symbols`` is None).

    Args:
        x: Encoder output for one utterance, presumably [T, 1, D]
            (time-major, batch of one) -- TODO confirm against the caller.
        out_len: Number of encoder frames to decode. It is used as
            ``range(out_len)``, so it must be an int or a 0-dim integer tensor.
        partial_hypotheses: Optional hypothesis from a previous call. Only
            its ``last_token`` and ``dec_state`` are carried over; its
            ``y_sequence``, ``score`` and ``timestep`` are not.

    Returns:
        A ``rnnt_utils.Hypothesis`` with these fields:
            y_sequence: the emitted non-blank token ids.
            timestep: the frame index at which each of those tokens was emitted.
            score: the sum of the maximum joint-network outputs for those tokens.
            dec_state: the decoder state passed through
                ``self.decoder.batch_select_state(..., 0)``.
            alignments: set only when ``self.preserve_alignments`` is true.
                Holds one list per decoded frame, containing every arg-max
                token id (blank included) predicted at that frame.
    """
    # x: [T, 1, D] (per the original author's note)
    # Start from an empty hypothesis: no tokens, no decoder state.
    hypothesis = rnnt_utils.Hypothesis(score=0.0, y_sequence=[], dec_state=None, timestep=[], last_token=None)

    if partial_hypotheses is not None:
        # Resume from the previous call's last token and decoder state.
        hypothesis.last_token = partial_hypotheses.last_token
        if partial_hypotheses.dec_state is not None:
            hypothesis.dec_state = self.decoder.batch_concat_states([partial_hypotheses.dec_state])
            hypothesis.dec_state = _states_to_device(hypothesis.dec_state, x.device)

    if self.preserve_alignments:
        # Ragged T x U list: one inner list per frame, holding the arg-max
        # token ids predicted at that frame.
        hypothesis.alignments = [[]]

    for time_idx in range(out_len):
        # x[time_idx:time_idx + 1]: frame t with its leading time dim kept.
        f = x.narrow(dim=0, start=time_idx, length=1)

        not_blank = True
        symbols_added = 0

        # Keep predicting for this frame until blank is predicted or the
        # per-frame symbol cap (if any) is reached.
        while not_blank and (self.max_symbols is None or symbols_added < self.max_symbols):
            if hypothesis.last_token is None and hypothesis.dec_state is None:
                # Nothing decoded yet: prime the prediction network with SOS.
                last_label = self._SOS
            else:
                # Feed back the previously emitted token.
                last_label = label_collate([[hypothesis.last_token]])

            # Prediction network step, then joint network step.
            g, hidden_prime = self._pred_step(last_label, hypothesis.dec_state)
            logp = self._joint_step(f, g, log_normalize=None)[0, 0, 0, :]
            del g

            # Upcast before max(); per the original note, torch.max(0) is
            # not available for FP16.
            if logp.dtype != torch.float32:
                logp = logp.float()

            # Greedy choice: best score v and its token index k.
            v, k = logp.max(0)
            k = k.item()
            del logp

            if self.preserve_alignments:
                # Record the predicted token id (blank included) for this frame.
                hypothesis.alignments[-1].append(k)

            if k == self._blank_index:
                # Blank ends this frame; move on to the next encoder frame.
                not_blank = False
            else:
                # Non-blank: emit the token and advance the decoder state.
                hypothesis.y_sequence.append(k)
                hypothesis.score += float(v)
                hypothesis.timestep.append(time_idx)
                hypothesis.dec_state = hidden_prime
                hypothesis.last_token = k

            symbols_added += 1

        if self.preserve_alignments:
            # Open the next frame's alignment list here, after the inner loop,
            # instead of only when blank is predicted. Otherwise, when
            # max_symbols stops the loop, the next frame's predictions would
            # be appended to this frame's list and misalign the T x U structure.
            hypothesis.alignments.append([])

    if self.preserve_alignments:
        # Drop the trailing empty list opened after the last frame.
        if len(hypothesis.alignments[-1]) == 0:
            del hypothesis.alignments[-1]

    # Unpack the hidden states: select batch entry 0 from the batched decoder state.
    hypothesis.dec_state = self.decoder.batch_select_state(hypothesis.dec_state, 0)

    return hypothesis