def greedy_search(self, h: torch.Tensor) -> List[Hypothesis]:
    """Greedy (best-path) search for the transformer-transducer.

    Args:
        h: Encoded speech features. (T_max, D_enc)

    Returns:
        One-element list containing the 1-best hypothesis.
    """
    cache = {}
    hyp = Hypothesis(
        score=0.0,
        yseq=[self.blank],
        dec_state=self.decoder.init_state(1),
    )
    dec_out, dec_state, _ = self.decoder.score(hyp, cache)

    for enc_vec in h:
        log_probs = torch.log_softmax(self.joint_network(enc_vec, dec_out), dim=-1)
        best_logp, best_tok = torch.max(log_probs, dim=-1)

        if best_tok == self.blank:
            # Blank emission: keep the decoder state and move to the next frame.
            continue

        hyp.yseq.append(int(best_tok))
        hyp.score += float(best_logp)
        # Commit the state that produced this token, then advance the decoder.
        hyp.dec_state = dec_state
        dec_out, dec_state, _ = self.decoder.score(hyp, cache)

    return [hyp]
def greedy_search(self, enc_out: torch.Tensor) -> List[Hypothesis]:
    """Greedy search implementation.

    Args:
        enc_out: Encoder output sequence. (T, D_enc)

    Returns:
        One-element list containing the 1-best hypothesis.
    """
    cache = {}
    hyp = Hypothesis(
        score=0.0,
        yseq=[self.blank_id],
        dec_state=self.decoder.init_state(1),
    )
    dec_out, dec_state, _ = self.decoder.score(hyp, cache)

    for enc_frame in enc_out:
        joint_out = self.joint_network(
            enc_frame, dec_out, quantization=self.quantization
        )
        # Temperature-scaled log-probabilities over the output vocabulary.
        logp = torch.log_softmax(joint_out / self.softmax_temperature, dim=-1)
        best_logp, best_id = torch.max(logp, dim=-1)

        if best_id == self.blank_id:
            # Blank emission: decoder state is unchanged for the next frame.
            continue

        hyp.yseq.append(int(best_id))
        hyp.score += float(best_logp)
        # Commit the state that produced this token, then advance the decoder.
        hyp.dec_state = dec_state
        dec_out, dec_state, _ = self.decoder.score(hyp, cache)

    return [hyp]