def align_length_sync_decoding(self, h: torch.Tensor) -> List[Hypothesis]:
    """Alignment-length synchronous beam search implementation.

    Based on https://ieeexplore.ieee.org/document/9053040

    Args:
        h: Encoded speech features (T_max, D_enc)

    Returns:
        nbest_hyps: N-best decoding results

    """
    beam = min(self.beam_size, self.vocab_size)

    h_length = int(h.size(0))
    # U (label) axis is capped so that t = i - u stays a valid frame index.
    u_max = min(self.u_max, (h_length - 1))

    init_tensor = h.unsqueeze(0)

    beam_state = self.decoder.init_state(
        torch.zeros((beam, self.hidden_size)))

    B = [
        Hypothesis(
            yseq=[self.blank],
            score=0.0,
            dec_state=self.decoder.select_state(beam_state, 0),
        )
    ]
    final = []

    if self.lm:
        # Word-level LMs wrap the character predictor; unwrap accordingly.
        if hasattr(self.lm.predictor, "wordlm"):
            lm_model = self.lm.predictor.wordlm
            lm_type = "wordlm"
        else:
            lm_model = self.lm.predictor
            lm_type = "lm"

        B[0].lm_state = init_lm_state(lm_model)

        lm_layers = len(lm_model.rnn)

    cache = {}

    # i indexes the alignment length (t + u); a hypothesis with u emitted
    # labels therefore reads encoder frame t = i - u.
    for i in range(h_length + u_max):
        A = []

        B_ = []
        h_states = []
        for hyp in B:
            u = len(hyp.yseq) - 1
            # Fix: frame index is (i - u). The previous (i - u + 1) skipped
            # encoder frame h[0] entirely (consistent with the corrected
            # ALSD implementation elsewhere in this file).
            t = i - u

            if t > (h_length - 1):
                continue

            B_.append(hyp)
            h_states.append((t, h[t]))

        if B_:
            beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
                B_, beam_state, cache, init_tensor)

            # Renamed comprehension variable so it no longer shadows the
            # encoder output argument `h`.
            h_enc = torch.stack([hs[1] for hs in h_states])

            beam_logp = torch.log_softmax(
                self.decoder.joint_network(h_enc, beam_y), dim=-1)
            # Top-k over non-blank labels only (index 0 is blank).
            beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)

            if self.lm:
                beam_lm_states = create_lm_batch_state(
                    [b.lm_state for b in B_], lm_type, lm_layers)
                beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                    beam_lm_states, beam_lm_tokens, len(B_))

            # `j` used here so the outer alignment index `i` is not shadowed.
            for j, hyp in enumerate(B_):
                # Blank transition: same label sequence, advance in time.
                new_hyp = Hypothesis(
                    score=(hyp.score + float(beam_logp[j, 0])),
                    yseq=hyp.yseq[:],
                    dec_state=hyp.dec_state,
                    lm_state=hyp.lm_state,
                )

                A.append(new_hyp)

                # Consuming blank on the last frame finalizes the hypothesis.
                if h_states[j][0] == (h_length - 1):
                    final.append(new_hyp)

                # Non-blank expansions (+1 restores original label indices).
                for logp, k in zip(beam_topk[0][j], beam_topk[1][j] + 1):
                    new_hyp = Hypothesis(
                        score=(hyp.score + float(logp)),
                        yseq=(hyp.yseq[:] + [int(k)]),
                        dec_state=self.decoder.select_state(beam_state, j),
                        lm_state=hyp.lm_state,
                    )

                    if self.lm:
                        new_hyp.score += self.lm_weight * beam_lm_scores[j, k]
                        new_hyp.lm_state = select_lm_state(
                            beam_lm_states, j, lm_type, lm_layers)

                    A.append(new_hyp)

            B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
            # Merge hypotheses sharing an identical label sequence.
            B = recombine_hyps(B)

    if final:
        return self.sort_nbest(final)
    else:
        return B
def nsc_beam_search(self, h: torch.Tensor) -> List[Hypothesis]:
    """N-step constrained beam search implementation.

    Based and modified from https://arxiv.org/pdf/2002.03577.pdf.
    Please reference ESPnet (b-flo, PR #2444) for any usage outside ESPnet
    until further modifications.

    Note: the algorithm is not in his "complete" form but works almost as
    intended.

    Args:
        h: Encoded speech features (T_max, D_enc)

    Returns:
        nbest_hyps: N-best decoding results

    """
    beam = min(self.beam_size, self.vocab_size)
    # Non-blank expansions can pick at most (vocab_size - 1) labels.
    beam_k = min(beam, (self.vocab_size - 1))

    init_tensor = h.unsqueeze(0)
    blank_tensor = init_tensor.new_zeros(1, dtype=torch.long)

    beam_state = self.decoder.init_state(
        torch.zeros((beam, self.hidden_size)))

    init_tokens = [
        Hypothesis(
            yseq=[self.blank],
            score=0.0,
            dec_state=self.decoder.select_state(beam_state, 0),
        )
    ]

    cache = {}

    # Pre-compute the decoder output for the initial blank token.
    beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
        init_tokens, beam_state, cache, init_tensor)

    state = self.decoder.select_state(beam_state, 0)

    if self.lm:
        beam_lm_states, beam_lm_scores = self.lm.buff_predict(
            None, beam_lm_tokens, 1)

        # Word-level LMs wrap the actual predictor; unwrap accordingly.
        if hasattr(self.lm.predictor, "wordlm"):
            lm_model = self.lm.predictor.wordlm
            lm_type = "wordlm"
        else:
            lm_model = self.lm.predictor
            lm_type = "lm"

        lm_layers = len(lm_model.rnn)

        lm_state = select_lm_state(beam_lm_states, 0, lm_type, lm_layers)
        lm_scores = beam_lm_scores[0]
    else:
        lm_state = None
        lm_scores = None

    kept_hyps = [
        Hypothesis(
            yseq=[self.blank],
            score=0.0,
            dec_state=state,
            y=[beam_y[0]],
            lm_state=lm_state,
            lm_scores=lm_scores,
        )
    ]

    for hi in h:
        # Longest hypotheses first, as required by the prefix merge below.
        hyps = sorted(kept_hyps, key=lambda x: len(x.yseq), reverse=True)
        kept_hyps = []

        h_enc = hi.unsqueeze(0)

        # Prefix merge: fold the probability mass of each hypothesis that is
        # a prefix (within prefix_alpha) into the longer hypothesis.
        for j in range(len(hyps) - 1):
            for i in range((j + 1), len(hyps)):
                if (is_prefix(hyps[j].yseq, hyps[i].yseq)
                        and (len(hyps[j].yseq)
                             - len(hyps[i].yseq)) <= self.prefix_alpha):
                    next_id = len(hyps[i].yseq)

                    ytu = torch.log_softmax(self.decoder.joint_network(
                        hi, hyps[i].y[-1]), dim=0)

                    curr_score = hyps[i].score + float(
                        ytu[hyps[j].yseq[next_id]])

                    # Walk the remaining labels of the longer hypothesis.
                    for k in range(next_id, (len(hyps[j].yseq) - 1)):
                        ytu = torch.log_softmax(self.decoder.joint_network(
                            hi, hyps[j].y[k]), dim=0)

                        curr_score += float(ytu[hyps[j].yseq[k + 1]])

                    # Log-domain sum of the two paths.
                    hyps[j].score = np.logaddexp(hyps[j].score, curr_score)

        S = []  # hypotheses that consumed blank (survive this frame)
        V = []  # non-blank expansion candidates
        for n in range(self.nstep):
            beam_y = torch.stack([hyp.y[-1] for hyp in hyps])

            beam_logp = torch.log_softmax(self.decoder.joint_network(
                h_enc, beam_y), dim=-1)
            beam_topk = beam_logp[:, 1:].topk(beam_k, dim=-1)

            if self.lm:
                beam_lm_scores = torch.stack(
                    [hyp.lm_scores for hyp in hyps])

            for i, hyp in enumerate(hyps):
                # Candidate set: top-k non-blank labels plus blank.
                i_topk = (
                    torch.cat((beam_topk[0][i], beam_logp[i, 0:1])),
                    torch.cat((beam_topk[1][i] + 1, blank_tensor)),
                )

                for logp, k in zip(*i_topk):
                    new_hyp = Hypothesis(
                        yseq=hyp.yseq[:],
                        score=(hyp.score + float(logp)),
                        y=hyp.y[:],
                        dec_state=hyp.dec_state,
                        lm_state=hyp.lm_state,
                        lm_scores=hyp.lm_scores,
                    )

                    if k == self.blank:
                        S.append(new_hyp)
                    else:
                        new_hyp.yseq.append(int(k))

                        if self.lm:
                            new_hyp.score += self.lm_weight * float(
                                beam_lm_scores[i, k])

                        V.append(new_hyp)

            V = sorted(V, key=lambda x: x.score, reverse=True)
            # Drop candidates duplicating current hypotheses, keep top beam.
            V = substract(V, hyps)[:beam]

            l_state = [v.dec_state for v in V]
            l_tokens = [v.yseq for v in V]

            beam_state = self.decoder.create_batch_states(
                beam_state, l_state, l_tokens)
            beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
                V, beam_state, cache, init_tensor)

            if self.lm:
                beam_lm_states = create_lm_batch_state(
                    [v.lm_state for v in V], lm_type, lm_layers)
                beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                    beam_lm_states, beam_lm_tokens, len(V))

            if n < (self.nstep - 1):
                # Intermediate step: refresh states and expand again.
                for i, v in enumerate(V):
                    v.y.append(beam_y[i])

                    v.dec_state = self.decoder.select_state(beam_state, i)

                    if self.lm:
                        v.lm_state = select_lm_state(
                            beam_lm_states, i, lm_type, lm_layers)
                        v.lm_scores = beam_lm_scores[i]

                hyps = V[:]
            else:
                # Last step: add the blank log-prob so V is comparable to S.
                beam_logp = torch.log_softmax(self.decoder.joint_network(
                    h_enc, beam_y), dim=-1)

                for i, v in enumerate(V):
                    if self.nstep != 1:
                        v.score += float(beam_logp[i, 0])

                    v.y.append(beam_y[i])

                    v.dec_state = self.decoder.select_state(beam_state, i)

                    if self.lm:
                        v.lm_state = select_lm_state(
                            beam_lm_states, i, lm_type, lm_layers)
                        v.lm_scores = beam_lm_scores[i]

        kept_hyps = sorted((S + V), key=lambda x: x.score, reverse=True)[:beam]

    return self.sort_nbest(kept_hyps)
def nsc_beam_search(self, h: torch.Tensor) -> List[NSCHypothesis]:
    """N-step constrained beam search implementation.

    Based and modified from https://arxiv.org/pdf/2002.03577.pdf.
    Please reference ESPnet (b-flo, PR #2444) for any usage outside ESPnet
    until further modifications.

    Note: the algorithm is not in his "complete" form but works almost as
    intended.

    Args:
        h: Encoded speech features (T_max, D_enc)

    Returns:
        nbest_hyps: N-best decoding results

    """
    beam = min(self.beam_size, self.vocab_size)
    # Non-blank expansions can pick at most (vocab_size - 1) labels.
    beam_k = min(beam, (self.vocab_size - 1))

    beam_state = self.decoder.init_state(beam)

    init_tokens = [
        NSCHypothesis(
            yseq=[self.blank],
            score=0.0,
            dec_state=self.decoder.select_state(beam_state, 0),
        )
    ]

    cache = {}

    # Pre-compute the decoder output for the initial blank token.
    beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
        init_tokens,
        beam_state,
        cache,
        self.use_lm,
    )

    state = self.decoder.select_state(beam_state, 0)

    if self.use_lm:
        beam_lm_states, beam_lm_scores = self.lm.buff_predict(
            None, beam_lm_tokens, 1)
        lm_state = select_lm_state(beam_lm_states, 0, self.lm_layers,
                                   self.is_wordlm)
        lm_scores = beam_lm_scores[0]
    else:
        lm_state = None
        lm_scores = None

    kept_hyps = [
        NSCHypothesis(
            yseq=[self.blank],
            score=0.0,
            dec_state=state,
            y=[beam_y[0]],
            lm_state=lm_state,
            lm_scores=lm_scores,
        )
    ]

    for hi in h:
        # Longest hypotheses first, as required by the prefix merge below.
        hyps = sorted(kept_hyps, key=lambda x: len(x.yseq), reverse=True)
        kept_hyps = []

        h_enc = hi.unsqueeze(0)

        # Prefix merge: fold the probability mass of each hypothesis that is
        # a prefix (within prefix_alpha) into the longer hypothesis.
        for j, hyp_j in enumerate(hyps[:-1]):
            for hyp_i in hyps[(j + 1):]:
                curr_id = len(hyp_j.yseq)
                next_id = len(hyp_i.yseq)

                if (is_prefix(hyp_j.yseq, hyp_i.yseq)
                        and (curr_id - next_id) <= self.prefix_alpha):
                    ytu = torch.log_softmax(self.joint_network(
                        hi, hyp_i.y[-1]), dim=-1)

                    curr_score = hyp_i.score + float(
                        ytu[hyp_j.yseq[next_id]])

                    # Walk the remaining labels of the longer hypothesis.
                    for k in range(next_id, (curr_id - 1)):
                        ytu = torch.log_softmax(self.joint_network(
                            hi, hyp_j.y[k]), dim=-1)
                        curr_score += float(ytu[hyp_j.yseq[k + 1]])

                    # Log-domain sum of the two paths.
                    hyp_j.score = np.logaddexp(hyp_j.score, curr_score)

        S = []  # hypotheses that consumed blank (survive this frame)
        V = []  # expansion candidates
        for n in range(self.nstep):
            beam_y = torch.stack([hyp.y[-1] for hyp in hyps])

            beam_logp = torch.log_softmax(self.joint_network(
                h_enc, beam_y), dim=-1)
            beam_topk = beam_logp[:, 1:].topk(beam_k, dim=-1)

            for i, hyp in enumerate(hyps):
                # Blank transition goes into S.
                S.append(
                    NSCHypothesis(
                        yseq=hyp.yseq[:],
                        score=hyp.score + float(beam_logp[i, 0:1]),
                        y=hyp.y[:],
                        dec_state=hyp.dec_state,
                        lm_state=hyp.lm_state,
                        lm_scores=hyp.lm_scores,
                    ))
                # NOTE(review): the blank hypothesis is also placed in V;
                # it shares yseq with `hyp`, so it is presumably filtered by
                # substract(V, hyps) below — confirm against substract().
                V.append(S[-1])

                for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                    score = hyp.score + float(logp)

                    if self.use_lm:
                        score += self.lm_weight * float(hyp.lm_scores[k])

                    V.append(
                        NSCHypothesis(
                            yseq=hyp.yseq[:] + [int(k)],
                            score=score,
                            y=hyp.y[:],
                            dec_state=hyp.dec_state,
                            lm_state=hyp.lm_state,
                            lm_scores=hyp.lm_scores,
                        ))

            # Fix: removed stray trailing comma ("V.sort(...),"), which
            # built and discarded a one-element tuple on every step.
            V.sort(key=lambda x: x.score, reverse=True)
            # Drop candidates duplicating current hypotheses, keep top beam.
            V = substract(V, hyps)[:beam]

            beam_state = self.decoder.create_batch_states(
                beam_state,
                [v.dec_state for v in V],
                [v.yseq for v in V],
            )
            beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
                V,
                beam_state,
                cache,
                self.use_lm,
            )

            if self.use_lm:
                beam_lm_states = create_lm_batch_state(
                    [v.lm_state for v in V], self.lm_layers, self.is_wordlm)
                beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                    beam_lm_states, beam_lm_tokens, len(V))

            if n < (self.nstep - 1):
                # Intermediate step: refresh states and expand again.
                for i, v in enumerate(V):
                    v.y.append(beam_y[i])
                    v.dec_state = self.decoder.select_state(beam_state, i)

                    if self.use_lm:
                        v.lm_state = select_lm_state(
                            beam_lm_states, i, self.lm_layers,
                            self.is_wordlm)
                        v.lm_scores = beam_lm_scores[i]

                hyps = V[:]
            else:
                # Last step: add the blank log-prob so V is comparable to S.
                beam_logp = torch.log_softmax(self.joint_network(
                    h_enc, beam_y), dim=-1)

                for i, v in enumerate(V):
                    if self.nstep != 1:
                        v.score += float(beam_logp[i, 0])

                    v.y.append(beam_y[i])
                    v.dec_state = self.decoder.select_state(beam_state, i)

                    if self.use_lm:
                        v.lm_state = select_lm_state(
                            beam_lm_states, i, self.lm_layers,
                            self.is_wordlm)
                        v.lm_scores = beam_lm_scores[i]

        kept_hyps = sorted((S + V), key=lambda x: x.score, reverse=True)[:beam]

    return self.sort_nbest(kept_hyps)
def time_sync_decoding(self, h: torch.Tensor) -> List[Hypothesis]:
    """Time synchronous beam search implementation.

    Based on https://ieeexplore.ieee.org/document/9053040

    Args:
        h: Encoded speech features (T_max, D_enc)

    Returns:
        nbest_hyps: N-best decoding results

    """
    beam = min(self.beam_size, self.vocab_size)

    init_tensor = h.unsqueeze(0)

    beam_state = self.decoder.init_state(
        torch.zeros((beam, self.hidden_size)))

    B = [
        Hypothesis(
            yseq=[self.blank],
            score=0.0,
            dec_state=self.decoder.select_state(beam_state, 0),
        )
    ]

    if self.lm:
        # Word-level LMs wrap the actual predictor; unwrap accordingly.
        if hasattr(self.lm.predictor, "wordlm"):
            lm_model = self.lm.predictor.wordlm
            lm_type = "wordlm"
        else:
            lm_model = self.lm.predictor
            lm_type = "lm"

        B[0].lm_state = init_lm_state(lm_model)

        lm_layers = len(lm_model.rnn)

    cache = {}

    for hi in h:
        A = []
        C = B

        h_enc = hi.unsqueeze(0)

        # Allow up to max_sym_exp symbol expansions per encoder frame.
        for v in range(self.max_sym_exp):
            D = []

            beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
                C, beam_state, cache, init_tensor)

            beam_logp = torch.log_softmax(
                self.decoder.joint_network(h_enc, beam_y), dim=-1)
            # Top-k over non-blank labels only (index 0 is blank).
            beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)

            # Renamed comprehension variable so it no longer shadows the
            # encoder output argument `h`.
            seq_A = [a.yseq for a in A]

            for i, hyp in enumerate(C):
                if hyp.yseq not in seq_A:
                    # Blank transition: keep label sequence, add blank mass.
                    A.append(
                        Hypothesis(
                            score=(hyp.score + float(beam_logp[i, 0])),
                            yseq=hyp.yseq[:],
                            dec_state=hyp.dec_state,
                            lm_state=hyp.lm_state,
                        ))
                else:
                    # Same label sequence already in A: merge in log domain.
                    dict_pos = seq_A.index(hyp.yseq)

                    A[dict_pos].score = np.logaddexp(
                        A[dict_pos].score,
                        (hyp.score + float(beam_logp[i, 0])))

            # Fix: the guard was `v < self.max_sym_exp`, which is always
            # true inside `range(self.max_sym_exp)`; expansions built at the
            # last step were always discarded (C is reset to B for the next
            # frame). Guarding with (max_sym_exp - 1) skips that wasted
            # work without changing any kept hypothesis, consistent with
            # the corrected implementation elsewhere in this file.
            if v < (self.max_sym_exp - 1):
                if self.lm:
                    beam_lm_states = create_lm_batch_state(
                        [c.lm_state for c in C], lm_type, lm_layers)

                    beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                        beam_lm_states, beam_lm_tokens, len(C))

                for i, hyp in enumerate(C):
                    # Non-blank expansions (+1 restores label indices).
                    for logp, k in zip(beam_topk[0][i],
                                       beam_topk[1][i] + 1):
                        new_hyp = Hypothesis(
                            score=(hyp.score + float(logp)),
                            yseq=(hyp.yseq + [int(k)]),
                            dec_state=self.decoder.select_state(
                                beam_state, i),
                            lm_state=hyp.lm_state,
                        )

                        if self.lm:
                            new_hyp.score += self.lm_weight * beam_lm_scores[
                                i, k]
                            new_hyp.lm_state = select_lm_state(
                                beam_lm_states, i, lm_type, lm_layers)

                        D.append(new_hyp)

                C = sorted(D, key=lambda x: x.score, reverse=True)[:beam]

        B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]

    return self.sort_nbest(B)
def nsc_beam_search(self, enc_out: torch.Tensor) -> List[ExtendedHypothesis]:
    """N-step constrained beam search implementation.

    Based on/Modified from https://arxiv.org/pdf/2002.03577.pdf.
    Please reference ESPnet (b-flo, PR #2444) for any usage outside ESPnet
    until further modifications.

    Args:
        enc_out: Encoder output sequence. (T, D_enc)

    Returns:
        nbest_hyps: N-best hypothesis.

    """
    beam = min(self.beam_size, self.vocab_size)
    # Non-blank expansions can pick at most (vocab_size - 1) labels.
    beam_k = min(beam, (self.vocab_size - 1))

    beam_state = self.decoder.init_state(beam)

    init_tokens = [
        ExtendedHypothesis(
            yseq=[self.blank_id],
            score=0.0,
            dec_state=self.decoder.select_state(beam_state, 0),
        )
    ]

    cache = {}

    # Pre-compute the decoder output for the initial blank token.
    beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
        init_tokens,
        beam_state,
        cache,
        self.use_lm,
    )

    state = self.decoder.select_state(beam_state, 0)

    if self.use_lm:
        beam_lm_states, beam_lm_scores = self.lm.buff_predict(
            None, beam_lm_tokens, 1
        )
        lm_state = select_lm_state(
            beam_lm_states, 0, self.lm_layers, self.is_wordlm
        )
        lm_scores = beam_lm_scores[0]
    else:
        lm_state = None
        lm_scores = None

    kept_hyps = [
        ExtendedHypothesis(
            yseq=[self.blank_id],
            score=0.0,
            dec_state=state,
            dec_out=[beam_dec_out[0]],
            lm_state=lm_state,
            lm_scores=lm_scores,
        )
    ]

    for enc_out_t in enc_out:
        # Merge prefix hypotheses before expansion (longest first).
        hyps = self.prefix_search(
            sorted(kept_hyps, key=lambda x: len(x.yseq), reverse=True),
            enc_out_t,
        )
        kept_hyps = []

        beam_enc_out = enc_out_t.unsqueeze(0)

        S = []  # hypotheses that consumed blank (survive this frame)
        V = []  # non-blank expansion candidates
        for n in range(self.nstep):
            beam_dec_out = torch.stack([hyp.dec_out[-1] for hyp in hyps])

            beam_logp = torch.log_softmax(
                self.joint_network(beam_enc_out, beam_dec_out)
                / self.softmax_temperature,
                dim=-1,
            )
            # Top-k over non-blank labels only (index 0 is blank).
            beam_topk = beam_logp[:, 1:].topk(beam_k, dim=-1)

            for i, hyp in enumerate(hyps):
                # Blank transition goes into S.
                S.append(
                    ExtendedHypothesis(
                        yseq=hyp.yseq[:],
                        score=hyp.score + float(beam_logp[i, 0:1]),
                        dec_out=hyp.dec_out[:],
                        dec_state=hyp.dec_state,
                        lm_state=hyp.lm_state,
                        lm_scores=hyp.lm_scores,
                    )
                )

                # Non-blank expansions (+1 restores label indices).
                for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                    score = hyp.score + float(logp)

                    if self.use_lm:
                        score += self.lm_weight * float(hyp.lm_scores[k])

                    V.append(
                        ExtendedHypothesis(
                            yseq=hyp.yseq[:] + [int(k)],
                            score=score,
                            dec_out=hyp.dec_out[:],
                            dec_state=hyp.dec_state,
                            lm_state=hyp.lm_state,
                            lm_scores=hyp.lm_scores,
                        )
                    )

            V.sort(key=lambda x: x.score, reverse=True)
            # Drop candidates duplicating current hypotheses, keep top beam.
            V = subtract(V, hyps)[:beam]

            beam_state = self.decoder.create_batch_states(
                beam_state,
                [v.dec_state for v in V],
                [v.yseq for v in V],
            )
            beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
                V,
                beam_state,
                cache,
                self.use_lm,
            )

            if self.use_lm:
                beam_lm_states = create_lm_batch_states(
                    [v.lm_state for v in V], self.lm_layers, self.is_wordlm
                )
                beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                    beam_lm_states, beam_lm_tokens, len(V)
                )

            if n < (self.nstep - 1):
                # Intermediate step: refresh states and expand again.
                for i, v in enumerate(V):
                    v.dec_out.append(beam_dec_out[i])
                    v.dec_state = self.decoder.select_state(beam_state, i)

                    if self.use_lm:
                        v.lm_state = select_lm_state(
                            beam_lm_states, i, self.lm_layers, self.is_wordlm
                        )
                        v.lm_scores = beam_lm_scores[i]

                hyps = V[:]
            else:
                # Last step: add the blank log-prob so V is comparable to S.
                beam_logp = torch.log_softmax(
                    self.joint_network(beam_enc_out, beam_dec_out)
                    / self.softmax_temperature,
                    dim=-1,
                )

                for i, v in enumerate(V):
                    if self.nstep != 1:
                        v.score += float(beam_logp[i, 0])

                    v.dec_out.append(beam_dec_out[i])
                    v.dec_state = self.decoder.select_state(beam_state, i)

                    if self.use_lm:
                        v.lm_state = select_lm_state(
                            beam_lm_states, i, self.lm_layers, self.is_wordlm
                        )
                        v.lm_scores = beam_lm_scores[i]

        kept_hyps = sorted((S + V), key=lambda x: x.score, reverse=True)[:beam]

    return self.sort_nbest(kept_hyps)
def modified_adaptive_expansion_search(
    self, enc_out: torch.Tensor
) -> List[ExtendedHypothesis]:
    """It's the modified Adaptive Expansion Search (mAES) implementation.

    Based on/modified from https://ieeexplore.ieee.org/document/9250505
    and NSC.

    Args:
        enc_out: Encoder output sequence. (T, D_enc)

    Returns:
        nbest_hyps: N-best hypothesis.

    """
    beam = min(self.beam_size, self.vocab_size)

    beam_state = self.decoder.init_state(beam)

    init_tokens = [
        ExtendedHypothesis(
            yseq=[self.blank_id],
            score=0.0,
            dec_state=self.decoder.select_state(beam_state, 0),
        )
    ]

    cache = {}

    # Pre-compute the decoder output for the initial blank token.
    beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
        init_tokens,
        beam_state,
        cache,
        self.use_lm,
    )

    state = self.decoder.select_state(beam_state, 0)

    if self.use_lm:
        beam_lm_states, beam_lm_scores = self.lm.buff_predict(
            None, beam_lm_tokens, 1
        )
        lm_state = select_lm_state(
            beam_lm_states, 0, self.lm_layers, self.is_wordlm
        )
        lm_scores = beam_lm_scores[0]
    else:
        lm_state = None
        lm_scores = None

    kept_hyps = [
        ExtendedHypothesis(
            yseq=[self.blank_id],
            score=0.0,
            dec_state=state,
            dec_out=[beam_dec_out[0]],
            lm_state=lm_state,
            lm_scores=lm_scores,
        )
    ]

    for enc_out_t in enc_out:
        # Merge prefix hypotheses before expansion (longest first).
        hyps = self.prefix_search(
            sorted(kept_hyps, key=lambda x: len(x.yseq), reverse=True),
            enc_out_t,
        )
        kept_hyps = []

        beam_enc_out = enc_out_t.unsqueeze(0)

        list_b = []  # hypotheses ending in blank for this frame
        for n in range(self.nstep):
            beam_dec_out = torch.stack([h.dec_out[-1] for h in hyps])

            beam_logp = torch.log_softmax(
                self.joint_network(beam_enc_out, beam_dec_out)
                / self.softmax_temperature,
                dim=-1,
            )

            # Adaptive pruning: per-hypothesis (label, score) expansions
            # selected via the gamma/beta thresholds.
            k_expansions = select_k_expansions(
                hyps, beam_logp, beam, self.expansion_gamma, self.expansion_beta
            )

            list_exp = []  # non-blank expansions for the next step
            for i, hyp in enumerate(hyps):
                for k, new_score in k_expansions[i]:
                    new_hyp = ExtendedHypothesis(
                        yseq=hyp.yseq[:],
                        score=new_score,
                        dec_out=hyp.dec_out[:],
                        dec_state=hyp.dec_state,
                        lm_state=hyp.lm_state,
                        lm_scores=hyp.lm_scores,
                    )

                    if k == 0:
                        # Blank expansion: hypothesis is kept as-is.
                        list_b.append(new_hyp)
                    else:
                        new_hyp.yseq.append(int(k))

                        if self.use_lm:
                            new_hyp.score += self.lm_weight * float(
                                hyp.lm_scores[k]
                            )

                        list_exp.append(new_hyp)

            if not list_exp:
                # No non-blank expansion survived: frame is done.
                kept_hyps = sorted(list_b, key=lambda x: x.score, reverse=True)[
                    :beam
                ]

                break
            else:
                beam_state = self.decoder.create_batch_states(
                    beam_state,
                    [hyp.dec_state for hyp in list_exp],
                    [hyp.yseq for hyp in list_exp],
                )

                beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
                    list_exp,
                    beam_state,
                    cache,
                    self.use_lm,
                )

                if self.use_lm:
                    beam_lm_states = create_lm_batch_states(
                        [hyp.lm_state for hyp in list_exp],
                        self.lm_layers,
                        self.is_wordlm,
                    )
                    beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                        beam_lm_states, beam_lm_tokens, len(list_exp)
                    )

                if n < (self.nstep - 1):
                    # Intermediate step: refresh states, expand again.
                    for i, hyp in enumerate(list_exp):
                        hyp.dec_out.append(beam_dec_out[i])
                        hyp.dec_state = self.decoder.select_state(beam_state, i)

                        if self.use_lm:
                            hyp.lm_state = select_lm_state(
                                beam_lm_states, i, self.lm_layers, self.is_wordlm
                            )
                            hyp.lm_scores = beam_lm_scores[i]

                    hyps = list_exp[:]
                else:
                    # Last step: add blank mass so expansions are
                    # comparable to list_b, then keep the top beam.
                    beam_logp = torch.log_softmax(
                        self.joint_network(beam_enc_out, beam_dec_out)
                        / self.softmax_temperature,
                        dim=-1,
                    )

                    for i, hyp in enumerate(list_exp):
                        hyp.score += float(beam_logp[i, 0])

                        hyp.dec_out.append(beam_dec_out[i])
                        hyp.dec_state = self.decoder.select_state(beam_state, i)

                        if self.use_lm:
                            hyp.lm_state = select_lm_state(
                                beam_lm_states, i, self.lm_layers, self.is_wordlm
                            )
                            hyp.lm_scores = beam_lm_scores[i]

                    kept_hyps = sorted(
                        list_b + list_exp, key=lambda x: x.score, reverse=True
                    )[:beam]

    return self.sort_nbest(kept_hyps)
def align_length_sync_decoding(self, enc_out: torch.Tensor) -> List[Hypothesis]:
    """Alignment-length synchronous beam search implementation.

    Based on https://ieeexplore.ieee.org/document/9053040

    Args:
        h: Encoder output sequences. (T, D)

    Returns:
        nbest_hyps: N-best hypothesis.

    """
    beam = min(self.beam_size, self.vocab_size)

    t_max = int(enc_out.size(0))
    # U (label) axis is capped so that t = i - u stays a valid frame index.
    u_max = min(self.u_max, (t_max - 1))

    beam_state = self.decoder.init_state(beam)

    B = [
        Hypothesis(
            yseq=[self.blank_id],
            score=0.0,
            dec_state=self.decoder.select_state(beam_state, 0),
        )
    ]
    final = []
    cache = {}

    if self.use_lm and not self.is_wordlm:
        B[0].lm_state = init_lm_state(self.lm_predictor)

    # i indexes the alignment length (t + u); a hypothesis with u emitted
    # labels therefore reads encoder frame t = i - u.
    for i in range(t_max + u_max):
        A = []

        B_ = []
        B_enc_out = []
        for hyp in B:
            u = len(hyp.yseq) - 1
            t = i - u

            if t > (t_max - 1):
                continue

            B_.append(hyp)
            B_enc_out.append((t, enc_out[t]))

        if B_:
            beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
                B_,
                beam_state,
                cache,
                self.use_lm,
            )

            beam_enc_out = torch.stack([x[1] for x in B_enc_out])

            beam_logp = torch.log_softmax(
                self.joint_network(beam_enc_out, beam_dec_out)
                / self.softmax_temperature,
                dim=-1,
            )
            # Top-k over non-blank labels only (index 0 is blank).
            beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)

            if self.use_lm:
                beam_lm_states = create_lm_batch_states(
                    [b.lm_state for b in B_], self.lm_layers, self.is_wordlm
                )

                beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                    beam_lm_states, beam_lm_tokens, len(B_)
                )

            # NOTE(review): this loop variable shadows the outer alignment
            # index `i`; harmless here because the outer `for` rebinds it,
            # but bug-prone if code is added after this loop.
            for i, hyp in enumerate(B_):
                # Blank transition: same label sequence, advance in time.
                new_hyp = Hypothesis(
                    score=(hyp.score + float(beam_logp[i, 0])),
                    yseq=hyp.yseq[:],
                    dec_state=hyp.dec_state,
                    lm_state=hyp.lm_state,
                )

                A.append(new_hyp)

                # Consuming blank on the last frame finalizes the hypothesis.
                if B_enc_out[i][0] == (t_max - 1):
                    final.append(new_hyp)

                # Non-blank expansions (+1 restores label indices).
                for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                    new_hyp = Hypothesis(
                        score=(hyp.score + float(logp)),
                        yseq=(hyp.yseq[:] + [int(k)]),
                        dec_state=self.decoder.select_state(beam_state, i),
                        lm_state=hyp.lm_state,
                    )

                    if self.use_lm:
                        new_hyp.score += self.lm_weight * beam_lm_scores[i, k]

                        new_hyp.lm_state = select_lm_state(
                            beam_lm_states, i, self.lm_layers, self.is_wordlm
                        )

                    A.append(new_hyp)

            B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
            # Merge hypotheses sharing an identical label sequence.
            B = recombine_hyps(B)

    if final:
        return self.sort_nbest(final)
    else:
        return B
def time_sync_decoding(self, enc_out: torch.Tensor) -> List[Hypothesis]:
    """Time synchronous beam search implementation.

    Based on https://ieeexplore.ieee.org/document/9053040

    Args:
        enc_out: Encoder output sequence. (T, D)

    Returns:
        nbest_hyps: N-best hypothesis.

    """
    beam = min(self.beam_size, self.vocab_size)

    beam_state = self.decoder.init_state(beam)

    B = [
        Hypothesis(
            yseq=[self.blank_id],
            score=0.0,
            dec_state=self.decoder.select_state(beam_state, 0),
        )
    ]
    cache = {}

    if self.use_lm and not self.is_wordlm:
        B[0].lm_state = init_lm_state(self.lm_predictor)

    for enc_out_t in enc_out:
        A = []
        C = B

        enc_out_t = enc_out_t.unsqueeze(0)

        # Allow up to max_sym_exp symbol expansions per encoder frame.
        for v in range(self.max_sym_exp):
            D = []

            beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
                C,
                beam_state,
                cache,
                self.use_lm,
            )

            beam_logp = torch.log_softmax(
                self.joint_network(enc_out_t, beam_dec_out)
                / self.softmax_temperature,
                dim=-1,
            )
            # Top-k over non-blank labels only (index 0 is blank).
            beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)

            seq_A = [h.yseq for h in A]

            for i, hyp in enumerate(C):
                if hyp.yseq not in seq_A:
                    # Blank transition: keep label sequence, add blank mass.
                    A.append(
                        Hypothesis(
                            score=(hyp.score + float(beam_logp[i, 0])),
                            yseq=hyp.yseq[:],
                            dec_state=hyp.dec_state,
                            lm_state=hyp.lm_state,
                        )
                    )
                else:
                    # Same label sequence already in A: merge in log domain.
                    dict_pos = seq_A.index(hyp.yseq)

                    A[dict_pos].score = np.logaddexp(
                        A[dict_pos].score, (hyp.score + float(beam_logp[i, 0]))
                    )

            # No expansion on the last step: its candidates could never be
            # carried over to the next frame.
            if v < (self.max_sym_exp - 1):
                if self.use_lm:
                    beam_lm_states = create_lm_batch_states(
                        [c.lm_state for c in C], self.lm_layers, self.is_wordlm
                    )

                    beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                        beam_lm_states, beam_lm_tokens, len(C)
                    )

                for i, hyp in enumerate(C):
                    # Non-blank expansions (+1 restores label indices).
                    for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                        new_hyp = Hypothesis(
                            score=(hyp.score + float(logp)),
                            yseq=(hyp.yseq + [int(k)]),
                            dec_state=self.decoder.select_state(beam_state, i),
                            lm_state=hyp.lm_state,
                        )

                        if self.use_lm:
                            new_hyp.score += self.lm_weight * beam_lm_scores[i, k]

                            new_hyp.lm_state = select_lm_state(
                                beam_lm_states, i, self.lm_layers, self.is_wordlm
                            )

                        D.append(new_hyp)

                C = sorted(D, key=lambda x: x.score, reverse=True)[:beam]

        B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]

    return self.sort_nbest(B)
def nsc_beam_search(decoder, h, recog_args, rnnlm=None):
    """N-step constrained beam search implementation.

    Based and modified from https://arxiv.org/pdf/2002.03577.pdf.
    Please reference ESPnet (b-flo, PR #2444) for any usage outside ESPnet
    until further modifications.

    Note: the algorithm is not in his "complete" form but works almost as
    intended.

    Args:
        decoder (class): decoder class
        h (torch.Tensor): encoder hidden state sequences (Tmax, Henc)
        recog_args (Namespace): argument Namespace containing options
        rnnlm (torch.nn.Module): language module

    Returns:
        nbest_hyps (list of dicts): n-best decoding results

    """
    beam = min(recog_args.beam_size, decoder.odim)
    # Non-blank expansions can pick at most (odim - 1) labels.
    beam_k = min(beam, (decoder.odim - 1))

    nstep = recog_args.nstep
    prefix_alpha = recog_args.prefix_alpha

    nbest = recog_args.nbest

    cache = {}

    init_tensor = h.unsqueeze(0)
    blank_tensor = init_tensor.new_zeros(1, dtype=torch.long)

    beam_state = decoder.init_state(torch.zeros((beam, decoder.dunits)))

    init_tokens = [
        Hypothesis(
            yseq=[decoder.blank],
            score=0.0,
            dec_state=decoder.select_state(beam_state, 0),
        )
    ]

    # Pre-compute the decoder output for the initial blank token.
    beam_y, beam_state, beam_lm_tokens = decoder.batch_score(
        init_tokens, beam_state, cache, init_tensor)

    state = decoder.select_state(beam_state, 0)

    if rnnlm:
        beam_lm_states, beam_lm_scores = rnnlm.buff_predict(
            None, beam_lm_tokens, 1)

        # Word-level LMs wrap the actual predictor; unwrap accordingly.
        if hasattr(rnnlm.predictor, "wordlm"):
            lm_model = rnnlm.predictor.wordlm
            lm_type = "wordlm"
        else:
            lm_model = rnnlm.predictor
            lm_type = "lm"

        lm_layers = len(lm_model.rnn)

        lm_state = select_lm_state(beam_lm_states, 0, lm_type, lm_layers)
        lm_scores = beam_lm_scores[0]
    else:
        lm_state = None
        lm_scores = None

    kept_hyps = [
        Hypothesis(
            yseq=[decoder.blank],
            score=0.0,
            dec_state=state,
            y=[beam_y[0]],
            lm_state=lm_state,
            lm_scores=lm_scores,
        )
    ]

    for hi in h:
        # Longest hypotheses first, as required by the prefix merge below.
        hyps = sorted(kept_hyps, key=lambda x: len(x.yseq), reverse=True)
        kept_hyps = []

        h_enc = hi.unsqueeze(0)

        # Prefix merge: fold the probability mass of each hypothesis that is
        # a prefix (within prefix_alpha) into the longer hypothesis.
        for j in range(len(hyps) - 1):
            for i in range((j + 1), len(hyps)):
                if (is_prefix(hyps[j].yseq, hyps[i].yseq)
                        and (len(hyps[j].yseq)
                             - len(hyps[i].yseq)) <= prefix_alpha):
                    next_id = len(hyps[i].yseq)

                    ytu = F.log_softmax(decoder.joint(hi, hyps[i].y[-1]), dim=0)

                    curr_score = hyps[i].score + float(
                        ytu[hyps[j].yseq[next_id]])

                    # Walk the remaining labels of the longer hypothesis.
                    for k in range(next_id, (len(hyps[j].yseq) - 1)):
                        ytu = F.log_softmax(decoder.joint(hi, hyps[j].y[k]),
                                            dim=0)

                        curr_score += float(ytu[hyps[j].yseq[k + 1]])

                    # Log-domain sum of the two paths.
                    hyps[j].score = np.logaddexp(hyps[j].score, curr_score)

        S = []  # hypotheses that consumed blank (survive this frame)
        V = []  # non-blank expansion candidates
        for n in range(nstep):
            beam_y = torch.stack([hyp.y[-1] for hyp in hyps])

            beam_logp = F.log_softmax(decoder.joint(h_enc, beam_y), dim=-1)
            beam_topk = beam_logp[:, 1:].topk(beam_k, dim=-1)

            if rnnlm:
                beam_lm_scores = torch.stack([hyp.lm_scores for hyp in hyps])

            for i, hyp in enumerate(hyps):
                # Candidate set: top-k non-blank labels plus blank.
                i_topk = (
                    torch.cat((beam_topk[0][i], beam_logp[i, 0:1])),
                    torch.cat((beam_topk[1][i] + 1, blank_tensor)),
                )

                for logp, k in zip(*i_topk):
                    new_hyp = Hypothesis(
                        yseq=hyp.yseq[:],
                        score=(hyp.score + float(logp)),
                        y=hyp.y[:],
                        dec_state=hyp.dec_state,
                        lm_state=hyp.lm_state,
                        lm_scores=hyp.lm_scores,
                    )

                    if k == decoder.blank:
                        S.append(new_hyp)
                    else:
                        new_hyp.yseq.append(int(k))

                        if rnnlm:
                            new_hyp.score += recog_args.lm_weight * float(
                                beam_lm_scores[i, k])

                        V.append(new_hyp)

            V = sorted(V, key=lambda x: x.score, reverse=True)
            # Drop candidates duplicating current hypotheses, keep top beam.
            V = substract(V, hyps)[:beam]

            l_state = [v.dec_state for v in V]
            l_tokens = [v.yseq for v in V]

            beam_state = decoder.create_batch_states(beam_state, l_state,
                                                     l_tokens)
            beam_y, beam_state, beam_lm_tokens = decoder.batch_score(
                V, beam_state, cache, init_tensor)

            if rnnlm:
                beam_lm_states = create_lm_batch_state(
                    [v.lm_state for v in V], lm_type, lm_layers)
                beam_lm_states, beam_lm_scores = rnnlm.buff_predict(
                    beam_lm_states, beam_lm_tokens, len(V))

            if n < (nstep - 1):
                # Intermediate step: refresh states and expand again.
                for i, v in enumerate(V):
                    v.y.append(beam_y[i])

                    v.dec_state = decoder.select_state(beam_state, i)

                    if rnnlm:
                        v.lm_state = select_lm_state(beam_lm_states, i,
                                                     lm_type, lm_layers)
                        v.lm_scores = beam_lm_scores[i]

                hyps = V[:]
            else:
                # Last step: add the blank log-prob so V is comparable to S.
                beam_logp = F.log_softmax(decoder.joint(h_enc, beam_y), dim=-1)

                for i, v in enumerate(V):
                    if nstep != 1:
                        v.score += float(beam_logp[i, 0])

                    v.y.append(beam_y[i])

                    v.dec_state = decoder.select_state(beam_state, i)

                    if rnnlm:
                        v.lm_state = select_lm_state(beam_lm_states, i,
                                                     lm_type, lm_layers)
                        v.lm_scores = beam_lm_scores[i]

        kept_hyps = sorted((S + V), key=lambda x: x.score, reverse=True)[:beam]

    # Final ranking is length-normalized, unlike the per-frame pruning.
    nbest_hyps = sorted(kept_hyps, key=lambda x: (x.score / len(x.yseq)),
                        reverse=True)[:nbest]

    return [asdict(n) for n in nbest_hyps]
def align_length_sync_decoding(decoder, h, recog_args, rnnlm=None):
    """Alignment-length synchronous beam search implementation.

    Based on https://ieeexplore.ieee.org/document/9053040

    Args:
        decoder (class): decoder class
        h (torch.Tensor): encoder hidden state sequences (Tmax, Henc)
        recog_args (Namespace): argument Namespace containing options
        rnnlm (torch.nn.Module): language module

    Returns:
        nbest_hyps (list of dicts): n-best decoding results

    """
    beam = min(recog_args.beam_size, decoder.odim)

    h_length = int(h.size(0))
    # U (label) axis is capped so that t = i - u stays a valid frame index.
    u_max = min(recog_args.u_max, (h_length - 1))

    nbest = recog_args.nbest

    init_tensor = h.unsqueeze(0)

    beam_state = decoder.init_state(torch.zeros((beam, decoder.dunits)))

    B = [
        Hypothesis(
            yseq=[decoder.blank],
            score=0.0,
            dec_state=decoder.select_state(beam_state, 0),
        )
    ]
    final = []

    if rnnlm:
        # Word-level LMs wrap the actual predictor; unwrap accordingly.
        if hasattr(rnnlm.predictor, "wordlm"):
            lm_model = rnnlm.predictor.wordlm
            lm_type = "wordlm"
        else:
            lm_model = rnnlm.predictor
            lm_type = "lm"

        B[0].lm_state = init_lm_state(lm_model)

        lm_layers = len(lm_model.rnn)

    cache = {}

    # i indexes the alignment length (t + u); a hypothesis with u emitted
    # labels therefore reads encoder frame t = i - u.
    for i in range(h_length + u_max):
        A = []

        B_ = []
        h_states = []
        for hyp in B:
            u = len(hyp.yseq) - 1
            # Fix: frame index is (i - u). The previous (i - u + 1) skipped
            # encoder frame h[0] entirely.
            t = i - u

            if t > (h_length - 1):
                continue

            B_.append(hyp)
            h_states.append((t, h[t]))

        if B_:
            beam_y, beam_state, beam_lm_tokens = decoder.batch_score(
                B_, beam_state, cache, init_tensor)

            # Renamed comprehension variable so it no longer shadows the
            # encoder output argument `h`.
            h_enc = torch.stack([hs[1] for hs in h_states])

            beam_logp = F.log_softmax(decoder.joint(h_enc, beam_y), dim=-1)
            # Top-k over non-blank labels only (index 0 is blank).
            beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)

            if rnnlm:
                beam_lm_states = create_lm_batch_state(
                    [b.lm_state for b in B_], lm_type, lm_layers)

                beam_lm_states, beam_lm_scores = rnnlm.buff_predict(
                    beam_lm_states, beam_lm_tokens, len(B_))

            # `j` used here so the outer alignment index `i` is not shadowed.
            for j, hyp in enumerate(B_):
                # Blank transition: same label sequence, advance in time.
                new_hyp = Hypothesis(
                    score=(hyp.score + float(beam_logp[j, 0])),
                    yseq=hyp.yseq[:],
                    dec_state=hyp.dec_state,
                    lm_state=hyp.lm_state,
                )

                A.append(new_hyp)

                # Consuming blank on the last frame finalizes the hypothesis.
                if h_states[j][0] == (h_length - 1):
                    final.append(new_hyp)

                # Non-blank expansions (+1 restores label indices).
                for logp, k in zip(beam_topk[0][j], beam_topk[1][j] + 1):
                    new_hyp = Hypothesis(
                        score=(hyp.score + float(logp)),
                        yseq=(hyp.yseq[:] + [int(k)]),
                        dec_state=decoder.select_state(beam_state, j),
                        lm_state=hyp.lm_state,
                    )

                    if rnnlm:
                        new_hyp.score += recog_args.lm_weight * beam_lm_scores[
                            j, k]

                        new_hyp.lm_state = select_lm_state(
                            beam_lm_states, j, lm_type, lm_layers)

                    A.append(new_hyp)

            B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
            # Merge hypotheses sharing an identical label sequence.
            B = recombine_hyps(B)

    if final:
        nbest_hyps = sorted(final, key=lambda x: x.score, reverse=True)[:nbest]
    else:
        nbest_hyps = B[:nbest]

    return [asdict(n) for n in nbest_hyps]
def time_sync_decoding(decoder, h, recog_args, rnnlm=None):
    """Time synchronous beam search implementation.

    Based on https://ieeexplore.ieee.org/document/9053040

    Args:
        decoder (class): decoder class
        h (torch.Tensor): encoder hidden state sequences (Tmax, Henc)
        recog_args (Namespace): argument Namespace containing options
        rnnlm (torch.nn.Module): language module

    Returns:
        nbest_hyps (list of dicts): n-best decoding results

    """
    beam = min(recog_args.beam_size, decoder.odim)

    max_sym_exp = recog_args.max_sym_exp
    nbest = recog_args.nbest

    init_tensor = h.unsqueeze(0)

    beam_state = decoder.init_state(torch.zeros((beam, decoder.dunits)))

    B = [
        Hypothesis(
            yseq=[decoder.blank],
            score=0.0,
            dec_state=decoder.select_state(beam_state, 0),
        )
    ]

    if rnnlm:
        # Word-level LMs wrap the actual predictor; unwrap accordingly.
        if hasattr(rnnlm.predictor, "wordlm"):
            lm_model = rnnlm.predictor.wordlm
            lm_type = "wordlm"
        else:
            lm_model = rnnlm.predictor
            lm_type = "lm"

        B[0].lm_state = init_lm_state(lm_model)

        lm_layers = len(lm_model.rnn)

    cache = {}

    for hi in h:
        A = []
        C = B

        h_enc = hi.unsqueeze(0)

        # Allow up to max_sym_exp symbol expansions per encoder frame.
        for v in range(max_sym_exp):
            D = []

            beam_y, beam_state, beam_lm_tokens = decoder.batch_score(
                C, beam_state, cache, init_tensor)

            beam_logp = F.log_softmax(decoder.joint(h_enc, beam_y), dim=-1)
            # Top-k over non-blank labels only (index 0 is blank).
            beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)

            # Renamed comprehension variable so it no longer shadows the
            # encoder output argument `h`.
            seq_A = [a.yseq for a in A]

            for i, hyp in enumerate(C):
                if hyp.yseq not in seq_A:
                    # Blank transition: keep label sequence, add blank mass.
                    A.append(
                        Hypothesis(
                            score=(hyp.score + float(beam_logp[i, 0])),
                            yseq=hyp.yseq[:],
                            dec_state=hyp.dec_state,
                            lm_state=hyp.lm_state,
                        ))
                else:
                    # Same label sequence already in A: merge in log domain.
                    dict_pos = seq_A.index(hyp.yseq)

                    A[dict_pos].score = np.logaddexp(
                        A[dict_pos].score,
                        (hyp.score + float(beam_logp[i, 0])))

            # Fix: the guard was `v < max_sym_exp`, which is always true
            # inside `range(max_sym_exp)`; expansions built at the last
            # step were always discarded (C is reset to B for the next
            # frame). Guarding with (max_sym_exp - 1) skips that wasted
            # work without changing any kept hypothesis.
            if v < (max_sym_exp - 1):
                if rnnlm:
                    beam_lm_states = create_lm_batch_state(
                        [c.lm_state for c in C], lm_type, lm_layers)

                    beam_lm_states, beam_lm_scores = rnnlm.buff_predict(
                        beam_lm_states, beam_lm_tokens, len(C))

                for i, hyp in enumerate(C):
                    # Non-blank expansions (+1 restores label indices).
                    for logp, k in zip(beam_topk[0][i],
                                       beam_topk[1][i] + 1):
                        new_hyp = Hypothesis(
                            score=(hyp.score + float(logp)),
                            yseq=(hyp.yseq + [int(k)]),
                            dec_state=decoder.select_state(beam_state, i),
                            lm_state=hyp.lm_state,
                        )

                        if rnnlm:
                            new_hyp.score += recog_args.lm_weight * beam_lm_scores[
                                i, k]
                            new_hyp.lm_state = select_lm_state(
                                beam_lm_states, i, lm_type, lm_layers)

                        D.append(new_hyp)

                C = sorted(D, key=lambda x: x.score, reverse=True)[:beam]

        B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]

    nbest_hyps = sorted(B, key=lambda x: x.score, reverse=True)[:nbest]

    return [asdict(n) for n in nbest_hyps]