# Shared imports for the three variants below. Hypothesis, recombine_hyps,
# init_lm_state, create_lm_batch_state, and select_lm_state are assumed to be
# provided by the surrounding transducer utilities.
from dataclasses import asdict
from typing import List

import torch
import torch.nn.functional as F


# Variant 1: method form; LM fusion goes through an RNNLM predictor wrapper.
def align_length_sync_decoding(self, h: torch.Tensor) -> List[Hypothesis]:
    """Alignment-length synchronous beam search implementation.

    Based on https://ieeexplore.ieee.org/document/9053040

    Args:
        h: Encoded speech features. (T_max, D_enc)

    Returns:
        nbest_hyps: N-best decoding results.

    """
    beam = min(self.beam_size, self.vocab_size)

    h_length = int(h.size(0))
    u_max = min(self.u_max, (h_length - 1))

    init_tensor = h.unsqueeze(0)
    beam_state = self.decoder.init_state(torch.zeros((beam, self.hidden_size)))

    # B holds the live beam; final collects hypotheses that have consumed the
    # last encoder frame.
    B = [
        Hypothesis(
            yseq=[self.blank],
            score=0.0,
            dec_state=self.decoder.select_state(beam_state, 0),
        )
    ]
    final = []

    if self.lm:
        if hasattr(self.lm.predictor, "wordlm"):
            lm_model = self.lm.predictor.wordlm
            lm_type = "wordlm"
        else:
            lm_model = self.lm.predictor
            lm_type = "lm"

        B[0].lm_state = init_lm_state(lm_model)

        lm_layers = len(lm_model.rnn)

    cache = {}

    # Each step i corresponds to one alignment length (frames + emitted tokens).
    for i in range(h_length + u_max):
        A = []

        B_ = []
        h_states = []

        for hyp in B:
            u = len(hyp.yseq) - 1
            t = i - u + 1

            if t > (h_length - 1):
                continue

            B_.append(hyp)
            h_states.append((t, h[t]))

        if B_:
            beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
                B_, beam_state, cache, init_tensor
            )

            h_enc = torch.stack([hs[1] for hs in h_states])

            beam_logp = torch.log_softmax(
                self.decoder.joint_network(h_enc, beam_y), dim=-1
            )
            beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)

            if self.lm:
                beam_lm_states = create_lm_batch_state(
                    [b.lm_state for b in B_], lm_type, lm_layers
                )
                beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                    beam_lm_states, beam_lm_tokens, len(B_)
                )

            for j, hyp in enumerate(B_):
                # Blank expansion: advance one frame without emitting a token.
                new_hyp = Hypothesis(
                    score=(hyp.score + float(beam_logp[j, 0])),
                    yseq=hyp.yseq[:],
                    dec_state=hyp.dec_state,
                    lm_state=hyp.lm_state,
                )

                A.append(new_hyp)

                if h_states[j][0] == (h_length - 1):
                    final.append(new_hyp)

                # Token expansions: emit one of the top-k non-blank symbols.
                for logp, k in zip(beam_topk[0][j], beam_topk[1][j] + 1):
                    new_hyp = Hypothesis(
                        score=(hyp.score + float(logp)),
                        yseq=(hyp.yseq[:] + [int(k)]),
                        dec_state=self.decoder.select_state(beam_state, j),
                        lm_state=hyp.lm_state,
                    )

                    if self.lm:
                        new_hyp.score += self.lm_weight * beam_lm_scores[j, k]
                        new_hyp.lm_state = select_lm_state(
                            beam_lm_states, j, lm_type, lm_layers
                        )

                    A.append(new_hyp)

            B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
            B = recombine_hyps(B)

    if final:
        return self.sort_nbest(final)
    else:
        return B
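# A usage sketch for the variant above, assuming it is a method of a beam
# search object that already holds the decoder, joint network, and search
# options. `searcher`, `encoder`, and `feats` are illustrative names, not part
# of the original code:
#
#     enc_out = encoder(feats)  # (T_max, D_enc) for a single utterance
#     nbest = searcher.align_length_sync_decoding(enc_out)
#     best_tokens = nbest[0].yseq[1:]  # strip the leading blank symbol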
# Variant 2: method form with a decoupled joint network and a batch-scoring
# LM interface.
def align_length_sync_decoding(self, enc_out: torch.Tensor) -> List[Hypothesis]:
    """Alignment-length synchronous beam search implementation.

    Based on https://ieeexplore.ieee.org/document/9053040

    Args:
        enc_out: Encoder output sequences. (T, D)

    Returns:
        nbest_hyps: N-best hypotheses.

    """
    beam = min(self.beam_size, self.vocab_size)

    t_max = int(enc_out.size(0))
    u_max = min(self.u_max, (t_max - 1))

    beam_state = self.decoder.init_state(beam)

    B = [
        Hypothesis(
            yseq=[self.blank_id],
            score=0.0,
            dec_state=self.decoder.select_state(beam_state, 0),
        )
    ]
    final = []
    cache = {}

    if self.use_lm:
        B[0].lm_state = self.lm.zero_state()

    for i in range(t_max + u_max):
        A = []

        B_ = []
        B_enc_out = []

        for hyp in B:
            u = len(hyp.yseq) - 1
            t = i - u

            if t > (t_max - 1):
                continue

            B_.append(hyp)
            B_enc_out.append((t, enc_out[t]))

        if B_:
            beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
                B_,
                beam_state,
                cache,
                self.use_lm,
            )

            beam_enc_out = torch.stack([x[1] for x in B_enc_out])

            beam_logp = torch.log_softmax(
                self.joint_network(beam_enc_out, beam_dec_out),
                dim=-1,
            )
            beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)

            if self.use_lm:
                beam_lm_scores, beam_lm_states = self.lm.batch_score(
                    beam_lm_tokens,
                    [b.lm_state for b in B_],
                    None,
                )

            for j, hyp in enumerate(B_):
                # Blank expansion: advance one frame without emitting a token.
                new_hyp = Hypothesis(
                    score=(hyp.score + float(beam_logp[j, 0])),
                    yseq=hyp.yseq[:],
                    dec_state=hyp.dec_state,
                    lm_state=hyp.lm_state,
                )

                A.append(new_hyp)

                if B_enc_out[j][0] == (t_max - 1):
                    final.append(new_hyp)

                # Token expansions: emit one of the top-k non-blank symbols.
                for logp, k in zip(beam_topk[0][j], beam_topk[1][j] + 1):
                    new_hyp = Hypothesis(
                        score=(hyp.score + float(logp)),
                        yseq=(hyp.yseq[:] + [int(k)]),
                        dec_state=self.decoder.select_state(beam_state, j),
                        lm_state=hyp.lm_state,
                    )

                    if self.use_lm:
                        new_hyp.score += self.lm_weight * beam_lm_scores[j, k]
                        new_hyp.lm_state = beam_lm_states[j]

                    A.append(new_hyp)

            B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
            B = recombine_hyps(B)

    if final:
        return self.sort_nbest(final)
    else:
        return B
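# A self-contained sketch of the index synchronization used above: at outer
# step i, a hypothesis with u emitted tokens reads encoder frame t = i - u, so
# every hypothesis expanded at step i shares the same alignment length t + u.
# `alsd_visited_cells` is an illustrative helper, not part of the original
# code; it only enumerates which (t, u) cells each step can touch.
def alsd_visited_cells(t_max: int, u_max: int):
    """Map each outer step i to the (t, u) cells it can expand."""
    visited = {}
    for i in range(t_max + u_max):
        # Every token count u in [0, u_max] with a valid frame index t = i - u.
        visited[i] = [(i - u, u) for u in range(u_max + 1) if 0 <= i - u < t_max]
    return visited

# Example: for a 4-frame utterance with at most 2 non-blank emissions,
# alsd_visited_cells(4, 2)[0] == [(0, 0)] and
# alsd_visited_cells(4, 2)[5] == [(3, 2)].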
# Variant 3: standalone function driven by a recog_args Namespace; returns
# plain dicts via asdict.
def align_length_sync_decoding(decoder, h, recog_args, rnnlm=None):
    """Alignment-length synchronous beam search implementation.

    Based on https://ieeexplore.ieee.org/document/9053040

    Args:
        decoder (class): decoder class
        h (torch.Tensor): encoder hidden state sequences (Tmax, Henc)
        recog_args (Namespace): argument Namespace containing options
        rnnlm (torch.nn.Module): language module

    Returns:
        nbest_hyps (list of dicts): n-best decoding results

    """
    beam = min(recog_args.beam_size, decoder.odim)

    h_length = int(h.size(0))
    u_max = min(recog_args.u_max, (h_length - 1))
    nbest = recog_args.nbest

    init_tensor = h.unsqueeze(0)
    beam_state = decoder.init_state(torch.zeros((beam, decoder.dunits)))

    B = [
        Hypothesis(
            yseq=[decoder.blank],
            score=0.0,
            dec_state=decoder.select_state(beam_state, 0),
        )
    ]
    final = []

    if rnnlm:
        if hasattr(rnnlm.predictor, "wordlm"):
            lm_model = rnnlm.predictor.wordlm
            lm_type = "wordlm"
        else:
            lm_model = rnnlm.predictor
            lm_type = "lm"

        B[0].lm_state = init_lm_state(lm_model)

        lm_layers = len(lm_model.rnn)

    cache = {}

    for i in range(h_length + u_max):
        A = []

        B_ = []
        h_states = []

        for hyp in B:
            u = len(hyp.yseq) - 1
            t = i - u + 1

            if t > (h_length - 1):
                continue

            B_.append(hyp)
            h_states.append((t, h[t]))

        if B_:
            beam_y, beam_state, beam_lm_tokens = decoder.batch_score(
                B_, beam_state, cache, init_tensor
            )

            h_enc = torch.stack([hs[1] for hs in h_states])

            beam_logp = F.log_softmax(decoder.joint(h_enc, beam_y), dim=-1)
            beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)

            if rnnlm:
                beam_lm_states = create_lm_batch_state(
                    [b.lm_state for b in B_], lm_type, lm_layers
                )
                beam_lm_states, beam_lm_scores = rnnlm.buff_predict(
                    beam_lm_states, beam_lm_tokens, len(B_)
                )

            for j, hyp in enumerate(B_):
                # Blank expansion: advance one frame without emitting a token.
                new_hyp = Hypothesis(
                    score=(hyp.score + float(beam_logp[j, 0])),
                    yseq=hyp.yseq[:],
                    dec_state=hyp.dec_state,
                    lm_state=hyp.lm_state,
                )

                A.append(new_hyp)

                if h_states[j][0] == (h_length - 1):
                    final.append(new_hyp)

                # Token expansions: emit one of the top-k non-blank symbols.
                for logp, k in zip(beam_topk[0][j], beam_topk[1][j] + 1):
                    new_hyp = Hypothesis(
                        score=(hyp.score + float(logp)),
                        yseq=(hyp.yseq[:] + [int(k)]),
                        dec_state=decoder.select_state(beam_state, j),
                        lm_state=hyp.lm_state,
                    )

                    if rnnlm:
                        new_hyp.score += recog_args.lm_weight * beam_lm_scores[j, k]
                        new_hyp.lm_state = select_lm_state(
                            beam_lm_states, j, lm_type, lm_layers
                        )

                    A.append(new_hyp)

            B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
            B = recombine_hyps(B)

    if final:
        nbest_hyps = sorted(final, key=lambda x: x.score, reverse=True)[:nbest]
    else:
        nbest_hyps = B[:nbest]

    return [asdict(n) for n in nbest_hyps]
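# A usage sketch for the functional variant above; `decoder`, `enc_out`, and
# the option values are illustrative, and only the fields the function
# actually reads (beam_size, u_max, nbest, lm_weight) are set:
#
#     from argparse import Namespace
#
#     recog_args = Namespace(beam_size=5, u_max=50, nbest=1, lm_weight=0.3)
#     nbest_hyps = align_length_sync_decoding(decoder, enc_out, recog_args)
#     best = nbest_hyps[0]["yseq"][1:]  # token ids, leading blank removed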