import logging

import numpy as np
import torch
import torch.nn as nn
import warp_rnnt

# NOTE: the local building blocks referenced below (PositionalEncoder,
# TransformerDecoderLayer, CTCDecoder, CTCPrefixScorer, AttentionLoc, the
# loss modules, and mask/tensor utilities such as make_src_mask and
# tensor2np) are assumed to be provided by the surrounding `asr.modeling`
# package.


class TransformerDecoder(nn.Module):
    def __init__(self, params, cmlm=False):
        super(TransformerDecoder, self).__init__()

        self.vocab_size = params.vocab_size
        self.embed = nn.Embedding(self.vocab_size, params.dec_hidden_size)
        self.pe = PositionalEncoder(params.dec_hidden_size, params.dropout_dec_rate)
        self.dec_num_layers = params.dec_num_layers

        # TODO: rename to `decoders`
        self.transformers = nn.ModuleList()
        for _ in range(self.dec_num_layers):
            self.transformers += [
                TransformerDecoderLayer(
                    dec_num_attention_heads=params.dec_num_attention_heads,
                    dec_hidden_size=params.dec_hidden_size,
                    dec_intermediate_size=params.dec_intermediate_size,
                    dropout_dec_rate=params.dropout_dec_rate,
                    dropout_attn_rate=params.dropout_attn_rate,
                )
            ]

        self.mtl_ctc_weight = params.mtl_ctc_weight
        if self.mtl_ctc_weight > 0:
            self.ctc = CTCDecoder(params)
        # default to False when the config does not define this flag,
        # so that `forward` can always read it
        self.mtl_ctc_add_sos_eos = getattr(params, "mtl_ctc_add_sos_eos", False)

        # normalize before
        # TODO: set `eps` to 1e-5 (default)
        self.norm = nn.LayerNorm(params.dec_hidden_size, eps=1e-12)
        self.output = nn.Linear(params.dec_hidden_size, self.vocab_size)

        self.cmlm = cmlm
        if self.cmlm:
            # TODO: label smoothing
            self.loss_fn = MaskedLMLoss(vocab_size=self.vocab_size)
        else:
            self.loss_fn = LabelSmoothingLoss(
                vocab_size=self.vocab_size,
                lsm_prob=params.lsm_prob,
                normalize_length=params.loss_normalize_length,
                normalize_batch=params.loss_normalize_batch,
            )

        self.kd_weight = params.kd_weight
        if self.kd_weight > 0:
            self.loss_fn = DistillLoss(
                vocab_size=self.vocab_size,
                soft_label_weight=params.kd_weight,
                lsm_prob=params.lsm_prob,
                normalize_length=params.loss_normalize_length,
                normalize_batch=params.loss_normalize_batch,
            )

        self.blank_id = params.blank_id
        self.eos_id = params.eos_id
        self.max_decode_ylen = params.max_decode_ylen

    def forward(
        self,
        eouts,
        elens,
        eouts_inter=None,
        ys=None,
        ylens=None,
        ys_in=None,
        ys_out=None,  # labels
        soft_labels=None,
        ps=None,
        plens=None,
    ):
        loss = 0
        loss_dict = {}

        # embedding + positional encoding
        ys_in = self.pe(self.embed(ys_in))
        emask = make_src_mask(elens)
        if self.cmlm:  # Conditional Masked LM
            ymask = make_src_mask(ylens)
        else:
            ymask = make_tgt_mask(ylens + 1)

        for layer_id in range(self.dec_num_layers):
            ys_in, ymask, eouts, emask = self.transformers[layer_id](
                ys_in, ymask, eouts, emask
            )
        ys_in = self.norm(ys_in)  # normalize before
        logits = self.output(ys_in)

        if ys_out is None:
            return logits

        if self.kd_weight > 0 and soft_labels is not None:
            # NOTE: ys_out (labels) have length ylens + 1
            loss_att_kd, loss_kd, loss_att = self.loss_fn(
                logits, ys_out, soft_labels, ylens + 1
            )
            loss += loss_att_kd
            loss_dict["loss_kd"] = loss_kd
            loss_dict["loss_att"] = loss_att
        else:
            if self.cmlm:
                loss_att = self.loss_fn(logits, labels=ys_out, ylens=None)
            else:
                # NOTE: ys_out (labels) have length ylens + 1
                loss_att = self.loss_fn(logits, ys_out, ylens + 1)
            loss += loss_att
            loss_dict["loss_att"] = loss_att

        if self.mtl_ctc_weight > 0:
            if self.mtl_ctc_add_sos_eos:
                ys, ylens = add_sos_eos(ys, ylens, eos_id=self.eos_id)
            # NOTE: KD is not applied to the auxiliary CTC loss
            loss_ctc, _, _ = self.ctc(
                eouts=eouts, elens=elens, ys=ys, ylens=ylens, soft_labels=None
            )
            loss += self.mtl_ctc_weight * loss_ctc  # auxiliary loss
            loss_dict["loss_ctc"] = loss_ctc

        loss_dict["loss_total"] = loss

        return loss, loss_dict, logits

    def forward_one_step(self, ys_in, ylens_in, eouts):
        ys_in = self.pe(self.embed(ys_in))
        ymask = make_tgt_mask(ylens_in)

        for layer_id in range(self.dec_num_layers):
            ys_in, ymask, eouts, _ = self.transformers[layer_id](
                ys_in, ymask, eouts, None
            )
        # only the last position is needed to predict the next token
        ys_in = self.norm(ys_in[:, -1])  # normalize before
        logits = self.output(ys_in)

        return logits

    def decode(
        self,
        eouts,
        elens,
        eouts_inter=None,
        beam_width=1,
        len_weight=0,
        lm=None,
        lm_weight=0,
        decode_ctc_weight=0,
        decode_phone=False,
    ):
        """Beam search decoding."""
        bs = eouts.size(0)

        if decode_ctc_weight == 1:
            print("CTC is used")
            # greedy
            return self.ctc.decode(eouts, elens, beam_width=1)

        assert bs == 1

        # init
        beam = {
            "hyp": [self.eos_id],  # <eos> is used as <sos>
            "score": 0.0,
            "score_ctc": 0.0,
            "ctc_state": None,
            "score_lm": 0.0,
            "lm_state": None,
        }
        if decode_ctc_weight > 0:
            ctc_logits = self.ctc(eouts, elens)
            ctc_log_probs = torch.log_softmax(ctc_logits, dim=-1)
            ctc_scorer = CTCPrefixScorer(
                tensor2np(ctc_log_probs.squeeze(0)),
                blank_id=self.blank_id,
                eos_id=self.eos_id,
            )
            beam["score_ctc"] = 0.0
            beam["ctc_state"] = ctc_scorer.initial_state()
            ctc_beam_width = min(
                ctc_log_probs.size(2), int(beam_width * CTC_BEAM_WIDTH_RATIO)
            )
        beams = [beam]
        results = []

        for i in range(self.max_decode_ylen):
            new_beams = []

            for beam in beams:
                ys_in = torch.tensor([beam["hyp"]]).to(eouts.device)
                ylens_in = torch.tensor([i + 1]).to(eouts.device)
                scores_att = torch.log_softmax(
                    self.forward_one_step(ys_in, ylens_in, eouts), dim=-1
                )  # (1, vocab)
                scores = scores_att

                if lm_weight > 0:
                    scores_lm, _ = lm.predict(ys_in, ylens_in, states=None)  # (1, vocab)
                    # NOTE: out-of-place so that `scores_att` stays intact
                    # for the CTC rescoring below
                    scores = scores + lm_weight * scores_lm[:, : self.vocab_size]

                if decode_ctc_weight > 0:
                    score_ctc_prev = beam["score_ctc"]
                    ctc_state_prev = beam["ctc_state"]
                    scores_topb, v_topb = torch.topk(scores, k=ctc_beam_width, dim=1)
                    scores_ctc, ctc_state = ctc_scorer(
                        beam["hyp"], v_topb[0], ctc_state_prev
                    )
                    # re-calculate the scores of the CTC candidates
                    scores = (1 - decode_ctc_weight) * scores_att[:, v_topb[0]] \
                        + decode_ctc_weight * np2tensor(scores_ctc - score_ctc_prev)
                    if lm_weight > 0:
                        scores = scores + lm_weight * scores_lm[:, v_topb[0]]
                    scores_topk, ids_topk = torch.topk(scores, k=beam_width, dim=1)
                    v_topk = v_topb[:, ids_topk[0]]
                else:
                    scores_topk, v_topk = torch.topk(scores, k=beam_width, dim=1)

                for j in range(beam_width):
                    new_beam = {}
                    new_beam["score"] = beam["score"] + float(scores_topk[0, j])
                    new_beam["hyp"] = beam["hyp"] + [int(v_topk[0, j])]
                    if decode_ctc_weight > 0:
                        new_beam["score_ctc"] = scores_ctc[ids_topk[0, j]]
                        new_beam["ctc_state"] = ctc_state[ids_topk[0, j]]
                    new_beams.append(new_beam)

            # update `beams`
            beams = sorted(new_beams, key=lambda x: x["score"], reverse=True)[:beam_width]

            beams_extend = []
            for beam in beams:
                # ended hypotheses
                if beam["hyp"][-1] == self.eos_id:
                    hyp_noeos = strip_eos(beam["hyp"], self.eos_id)
                    # a hypothesis consisting only of <eos> is not acceptable
                    if len(hyp_noeos) < 1:
                        continue
                    # add length penalty
                    score = beam["score"] + len_weight * len(beam["hyp"])
                    results.append({"hyp": hyp_noeos, "score": score})
                    if len(results) >= beam_width:
                        break
                else:
                    beams_extend.append(beam)
            if len(results) >= beam_width:
                break
            beams = beams_extend

        results = sorted(results, key=lambda x: x["score"], reverse=True)
        hyps = [result["hyp"] for result in results]
        scores = [result["score"] for result in results]
        logits = None
        aligns = None

        return hyps, scores, logits, aligns
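
# ----------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original module): one
# teacher-forced training step with `TransformerDecoder`. The
# <eos>-as-<sos> convention follows `decode` above, which seeds every
# hypothesis with `eos_id`; batch size 1 keeps the padding handling
# trivial. The helper name and tensor layout are assumptions.
def _example_transformer_training_step(decoder, eouts, elens, ys, ylens):
    # ys: (1, L) label ids, ylens: (1,) label lengths
    eos = ys.new_full((1, 1), decoder.eos_id)
    ys_in = torch.cat([eos, ys], dim=1)   # decoder inputs: <eos> y1 ... yL
    ys_out = torch.cat([ys, eos], dim=1)  # targets: y1 ... yL <eos>, length ylens + 1
    loss, loss_dict, logits = decoder(
        eouts, elens, ys=ys, ylens=ylens, ys_in=ys_in, ys_out=ys_out
    )
    return loss, loss_dict
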
class LASDecoder(nn.Module):
    def __init__(self, params, phase="train"):
        super().__init__()

        self.enc_hidden_size = params.enc_hidden_size
        self.dec_hidden_size = params.dec_hidden_size
        self.dec_num_layers = params.dec_num_layers

        self.mtl_ctc_weight = params.mtl_ctc_weight
        if self.mtl_ctc_weight > 0:
            self.ctc = CTCDecoder(params)

        self.embed = nn.Embedding(params.vocab_size, params.embedding_size)
        self.dropout_emb = nn.Dropout(p=params.dropout_dec_rate)

        # Recurrency
        self.rnns = nn.ModuleList()
        input_size = params.embedding_size + params.enc_hidden_size  # input feeding
        for _ in range(self.dec_num_layers):
            self.rnns += [nn.LSTMCell(input_size, params.dec_hidden_size)]
            input_size = params.dec_hidden_size

        # Score
        self.score = AttentionLoc(
            key_dim=params.enc_hidden_size,
            query_dim=params.dec_hidden_size,
            attn_dim=params.attn_dim,
        )

        # Generate
        self.intermed = nn.Linear(
            params.enc_hidden_size + params.dec_hidden_size,
            params.dec_intermediate_size,
        )
        self.output = nn.Linear(params.dec_intermediate_size, params.vocab_size)

        self.dropout = nn.Dropout(p=params.dropout_dec_rate)

        self.loss_fn = LabelSmoothingLoss(
            vocab_size=params.vocab_size,
            lsm_prob=params.lsm_prob,
            normalize_length=params.loss_normalize_length,
            normalize_batch=params.loss_normalize_batch,
        )

        self.kd_weight = params.kd_weight
        if self.kd_weight > 0:
            self.loss_fn = DistillLoss(
                vocab_size=params.vocab_size,
                soft_label_weight=self.kd_weight,
                lsm_prob=params.lsm_prob,
                normalize_length=params.loss_normalize_length,
                normalize_batch=params.loss_normalize_batch,
            )

        self.eos_id = params.eos_id
        self.max_decode_ylen = params.max_decode_ylen

    def forward(
        self,
        eouts,
        elens,
        eouts_inter=None,
        ys=None,
        ylens=None,
        ys_in=None,
        ys_out=None,  # labels
        soft_labels=None,
        ps=None,
        plens=None,
    ):
        loss = 0
        loss_dict = {}

        bs = eouts.size(0)

        ys_emb = self.dropout_emb(self.embed(ys_in))
        dstate = None
        # context vector
        ctx = eouts.new_zeros(bs, 1, self.enc_hidden_size)
        attn_weight = None
        attn_mask = make_nopad_mask(elens).unsqueeze(2)

        logits = []
        for i in range(ys_in.size(1)):
            y_emb = ys_emb[:, i : i + 1]  # (bs, 1, embedding_size)
            logit, ctx, dstate, attn_weight = self.forward_one_step(
                y_emb, ctx, eouts, dstate, attn_weight, attn_mask
            )
            logits.append(logit)  # (bs, 1, dec_intermediate_size)
        logits = self.output(torch.cat(logits, dim=1))  # (bs, ylen, vocab)

        if self.kd_weight > 0 and soft_labels is not None:
            # NOTE: ys_out (labels) have length ylens + 1
            loss_att_kd, loss_kd, loss_att = self.loss_fn(
                logits, ys_out, soft_labels, ylens + 1
            )
            loss += loss_att_kd
            loss_dict["loss_kd"] = loss_kd
            loss_dict["loss_att"] = loss_att
        else:
            # NOTE: ys_out (labels) have length ylens + 1
            loss_att = self.loss_fn(logits, ys_out, ylens + 1)
            loss += loss_att
            loss_dict["loss_att"] = loss_att

        if self.mtl_ctc_weight > 0:
            # NOTE: KD is not applied to the auxiliary CTC loss
            loss_ctc, _, _ = self.ctc(
                eouts=eouts, elens=elens, ys=ys, ylens=ylens, soft_labels=None
            )
            loss += self.mtl_ctc_weight * loss_ctc  # auxiliary loss
            loss_dict["loss_ctc"] = loss_ctc

        loss_dict["loss_total"] = loss

        return loss, loss_dict, logits

    def forward_one_step(self, y_emb, ctx, eouts, dstate, attn_weight, attn_mask=None):
        # Recurrency -> Score -> Generate
        dstate, douts_1, douts_top = self.recurrency(
            torch.cat([y_emb, ctx], dim=-1), dstate
        )
        ctx, attn_weight = self.score(eouts, eouts, douts_1, attn_weight, attn_mask)
        logit = self.generate(ctx, douts_top)
        return logit, ctx, dstate, attn_weight

    def recurrency(self, dins, dstate=None):
        bs = dins.size(0)
        douts = dins.squeeze(1)

        if dstate is None:
            dstate = {}
            dstate["hs"] = torch.zeros(
                self.dec_num_layers, bs, self.dec_hidden_size, device=dins.device
            )
            dstate["cs"] = torch.zeros(
                self.dec_num_layers, bs, self.dec_hidden_size, device=dins.device
            )

        new_hs, new_cs = [], []
        for layer_id in range(self.dec_num_layers):
            h, c = self.rnns[layer_id](
                douts, (dstate["hs"][layer_id], dstate["cs"][layer_id])
            )
            new_hs.append(h)
            new_cs.append(c)
            douts = self.dropout(h)
            if layer_id == 0:
                # the first layer's output is used as the attention query
                douts_1 = douts.unsqueeze(1)

        new_dstate = {}
        new_dstate["hs"] = torch.stack(new_hs, dim=0)
        new_dstate["cs"] = torch.stack(new_cs, dim=0)

        douts_top = douts.unsqueeze(1)
        return new_dstate, douts_1, douts_top

    def generate(self, ctx, douts):
        out = self.intermed(torch.cat([ctx, douts], dim=-1))
        return torch.tanh(out)

    def decode(
        self,
        eouts,
        elens,
        eouts_inter=None,
        beam_width=1,
        len_weight=0,
        lm=None,
        lm_weight=0,
        decode_ctc_weight=0,
        decode_phone=False,
    ):
        """Beam search decoding."""
        bs = eouts.size(0)

        if decode_ctc_weight == 1:
            print("CTC is used")
            # greedy
            return self.ctc.decode(eouts, elens, beam_width=1)

        assert bs == 1

        # init
        beam = {
            "hyp": [self.eos_id],  # <eos> is used as <sos>
            "dstate": None,
            "score": 0.0,
            "las_ctx": eouts.new_zeros(bs, 1, self.enc_hidden_size),
            "las_dstate": None,
            "las_attn_weight": None,
            "score_ctc": 0.0,
            "ctc_state": None,
            "score_lm": 0.0,
            "lm_state": None,
        }
        if decode_ctc_weight > 0:
            # TODO: joint CTC decoding (cf. TransformerDecoder.decode) is
            # not implemented for the LAS decoder yet
            pass
        beams = [beam]
        results = []

        for i in range(self.max_decode_ylen):
            new_beams = []

            for beam in beams:
                y_in = torch.tensor([[beam["hyp"][-1]]]).to(eouts.device)
                y_emb = self.dropout_emb(self.embed(y_in))
                ctx = beam["las_ctx"]
                dstate = beam["las_dstate"]
                attn_weight = beam["las_attn_weight"]
                logit, ctx, dstate, attn_weight = self.forward_one_step(
                    y_emb, ctx, eouts, dstate, attn_weight
                )
                logit = self.output(logit)
                scores_att = torch.log_softmax(logit.squeeze(0), dim=-1)  # (1, vocab)
                scores = scores_att

                if lm_weight > 0:
                    # TODO: LM shallow fusion is not implemented for the LAS
                    # decoder yet
                    pass

                if decode_ctc_weight > 0:
                    # TODO: joint CTC decoding is not implemented yet
                    pass
                else:
                    scores_topk, v_topk = torch.topk(scores, k=beam_width, dim=1)

                for j in range(beam_width):
                    new_beam = {}
                    new_beam["score"] = beam["score"] + float(scores_topk[0, j])
                    new_beam["hyp"] = beam["hyp"] + [int(v_topk[0, j])]
                    # carry the updated attention states over to the next step
                    new_beam["las_ctx"] = ctx
                    new_beam["las_dstate"] = dstate
                    new_beam["las_attn_weight"] = attn_weight
                    if decode_ctc_weight > 0:
                        pass
                    new_beams.append(new_beam)

            # update `beams`
            beams = sorted(new_beams, key=lambda x: x["score"], reverse=True)[:beam_width]

            beams_extend = []
            for beam in beams:
                # ended hypotheses
                if beam["hyp"][-1] == self.eos_id:
                    hyp_noeos = strip_eos(beam["hyp"], self.eos_id)
                    # a hypothesis consisting only of <eos> is not acceptable
                    if len(hyp_noeos) < 1:
                        continue
                    # add length penalty
                    score = beam["score"] + len_weight * len(beam["hyp"])
                    results.append({"hyp": hyp_noeos, "score": score})
                    if len(results) >= beam_width:
                        break
                else:
                    beams_extend.append(beam)
            if len(results) >= beam_width:
                break
            beams = beams_extend

        results = sorted(results, key=lambda x: x["score"], reverse=True)
        hyps = [result["hyp"] for result in results]
        scores = [result["score"] for result in results]
        logits = None
        aligns = None

        return hyps, scores, logits, aligns
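
# ----------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original module):
# greedy decoding with `LASDecoder.forward_one_step`, threading the
# context vector, LSTM state, and attention weights exactly as `decode`
# above does for beam_width > 1. Assumes batch size 1; the helper name
# and the `max_len` cap are made up for illustration.
def _example_las_greedy(decoder, eouts, max_len=100):
    ctx = eouts.new_zeros(1, 1, decoder.enc_hidden_size)
    dstate, attn_weight = None, None
    hyp = [decoder.eos_id]  # <eos> doubles as <sos>
    for _ in range(max_len):
        y_in = torch.tensor([[hyp[-1]]], device=eouts.device)
        y_emb = decoder.dropout_emb(decoder.embed(y_in))
        logit, ctx, dstate, attn_weight = decoder.forward_one_step(
            y_emb, ctx, eouts, dstate, attn_weight
        )
        token_id = int(decoder.output(logit).argmax(dim=-1))
        if token_id == decoder.eos_id:
            break
        hyp.append(token_id)
    return hyp[1:]  # strip the leading <eos>
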
class RNNTDecoder(nn.Module):
    def __init__(self, params, phase="train"):
        super(RNNTDecoder, self).__init__()

        self.dec_num_layers = params.dec_num_layers
        self.dec_hidden_size = params.dec_hidden_size
        self.eos_id = params.eos_id
        self.blank_id = params.blank_id
        self.max_seq_len = 256

        self.mtl_ctc_weight = params.mtl_ctc_weight
        self.kd_weight = params.kd_weight

        # Prediction network (decoder)
        # TODO: -> class
        self.embed = nn.Embedding(params.vocab_size, params.embedding_size)
        self.dropout_emb = nn.Dropout(p=params.dropout_emb_rate)
        self.dropout = nn.Dropout(p=params.dropout_dec_rate)
        self.rnns = nn.ModuleList()
        input_size = params.embedding_size
        for _ in range(self.dec_num_layers):
            self.rnns += [
                nn.LSTM(
                    input_size=input_size,
                    hidden_size=params.dec_hidden_size,
                    num_layers=1,
                    batch_first=True,
                )
            ]
            input_size = params.dec_hidden_size

        # Joint network
        # TODO: -> class
        self.w_enc = nn.Linear(params.enc_hidden_size, params.joint_hidden_size)
        self.w_dec = nn.Linear(params.dec_hidden_size, params.joint_hidden_size)
        self.output = nn.Linear(params.joint_hidden_size, params.vocab_size)

        if self.mtl_ctc_weight > 0:
            self.ctc = CTCDecoder(params)

        if phase == "train":
            logging.info(f"warp_rnnt version: {warp_rnnt.__version__}")

        if self.kd_weight > 0 and phase == "train":
            self.kd_type = params.kd_type
            self.reduce_main_loss_kd = params.reduce_main_loss_kd
            if self.kd_type == "word":
                self.transducer_kd_loss = RNNTWordDistillLoss()
            elif self.kd_type == "align":
                self.transducer_kd_loss = RNNTAlignDistillLoss()
                # CUDA is initialized only if the forced aligner is used
                from asr.modeling.decoders.rnnt_aligner import RNNTForcedAligner

                self.forced_aligner = RNNTForcedAligner(blank_id=self.blank_id)

    def forward(
        self,
        eouts,
        elens,
        eouts_inter=None,
        ys=None,
        ylens=None,
        ys_in=None,
        ys_out=None,
        soft_labels=None,
        ps=None,
        plens=None,
    ):
        loss = 0
        loss_dict = {}

        # Prediction network
        douts, _ = self.recurrency(ys_in, dstate=None)
        # Joint network
        logits = self.joint(eouts, douts)  # (B, T, L + 1, vocab)
        log_probs = torch.log_softmax(logits, dim=-1)
        assert log_probs.size(2) == ys.size(1) + 1

        # NOTE: rnnt_loss only accepts ys, elens, ylens with torch.int
        loss_rnnt = warp_rnnt.rnnt_loss(
            log_probs,
            ys.int(),
            elens.int(),
            ylens.int(),
            average_frames=False,
            reduction="mean",
            blank=self.blank_id,
            gather=False,
        )
        loss += loss_rnnt  # main loss
        loss_dict["loss_rnnt"] = loss_rnnt

        if self.mtl_ctc_weight > 0:
            # NOTE: KD is not applied to the auxiliary CTC loss
            loss_ctc, _, _ = self.ctc(
                eouts=eouts, elens=elens, ys=ys, ylens=ylens, soft_labels=None
            )
            loss += self.mtl_ctc_weight * loss_ctc  # auxiliary loss
            loss_dict["loss_ctc"] = loss_ctc

        if self.kd_weight > 0 and soft_labels is not None:
            if self.kd_type == "word":
                loss_kd = self.transducer_kd_loss(logits, soft_labels, elens, ylens)
            elif self.kd_type == "align":
                aligns = self.forced_aligner(log_probs, elens, ys, ylens)
                loss_kd = self.transducer_kd_loss(
                    logits, ys, soft_labels, aligns, elens, ylens
                )
            loss_dict["loss_kd"] = loss_kd

            if self.reduce_main_loss_kd:
                loss = (1 - self.kd_weight) * loss + self.kd_weight * loss_kd
            else:
                loss += self.kd_weight * loss_kd

        loss_dict["loss_total"] = loss

        return loss, loss_dict, logits

    def joint(self, eouts, douts):
        """Joint network."""
        eouts = eouts.unsqueeze(2)  # (B, T, 1, enc_hidden_size)
        douts = douts.unsqueeze(1)  # (B, 1, L, dec_hidden_size)
        out = torch.tanh(self.w_enc(eouts) + self.w_dec(douts))
        out = self.output(out)  # (B, T, L, vocab)
        return out

    def recurrency(self, ys_in, dstate):
        """Prediction network."""
        ys_emb = self.dropout_emb(self.embed(ys_in))
        bs = ys_emb.size(0)

        if dstate is None:
            dstate = {}
            dstate["hs"] = torch.zeros(
                self.dec_num_layers, bs, self.dec_hidden_size, device=ys_in.device
            )
            dstate["cs"] = torch.zeros(
                self.dec_num_layers, bs, self.dec_hidden_size, device=ys_in.device
            )

        new_hs, new_cs = [], []
        for layer_id in range(self.dec_num_layers):
            self.rnns[layer_id].flatten_parameters()

            ys_emb, (h, c) = self.rnns[layer_id](
                ys_emb,
                hx=(
                    dstate["hs"][layer_id : layer_id + 1],  # (1, B, dec_hidden_size)
                    dstate["cs"][layer_id : layer_id + 1],
                ),
            )
            new_hs.append(h)
            new_cs.append(c)
            ys_emb = self.dropout(ys_emb)

        new_dstate = {}
        new_dstate["hs"] = torch.cat(new_hs, dim=0)
        new_dstate["cs"] = torch.cat(new_cs, dim=0)

        return ys_emb, new_dstate

    def _greedy(self, eouts, elens, decode_ctc_weight=0):
        """Greedy decoding."""
        if decode_ctc_weight == 1:
            # greedy CTC decoding
            return self.ctc.decode(eouts, elens, beam_width=1)

        bs = eouts.size(0)

        hyps = []
        scores = []
        logits = None  # TODO
        aligns = []

        for b in range(bs):
            hyp = []
            align = []
            ys = eouts.new_zeros((1, 1), dtype=torch.long).fill_(self.eos_id)  # <sos>
            dout, dstate = self.recurrency(ys, None)

            T = elens[b]
            t = 0
            while t < T:
                out = self.joint(eouts[b : b + 1, t : t + 1], dout)  # (1, 1, 1, vocab)
                new_ys = out.squeeze(2).argmax(-1)
                token_id = new_ys[0].item()
                align.append(token_id)

                if token_id == self.blank_id:
                    # blank: move on to the next frame
                    t += 1
                else:
                    # non-blank: update the prediction network
                    hyp.append(token_id)
                    dout, dstate = self.recurrency(new_ys, dstate)

                if len(hyp) > self.max_seq_len:
                    break

            hyps.append(hyp)
            scores.append(None)  # TODO
            aligns.append(align)

        return hyps, scores, logits, aligns

    def _beam_search(
        self, eouts, elens, beam_width=1, len_weight=0, lm=None, lm_weight=0
    ):
        """Beam search decoding.

        Reference:
            ALIGNMENT-LENGTH SYNCHRONOUS DECODING FOR RNN TRANSDUCER
            https://ieeexplore.ieee.org/document/9053040
        """
        bs = eouts.size(0)
        assert bs == 1

        NUM_EXPANDS = 3

        # init
        beam = {
            "hyp": [self.eos_id],  # <sos>
            "score": 0.0,
            "score_asr": 0.0,
            "dstate": {
                "hs": torch.zeros(
                    self.dec_num_layers, bs, self.dec_hidden_size, device=eouts.device
                ),
                "cs": torch.zeros(
                    self.dec_num_layers, bs, self.dec_hidden_size, device=eouts.device
                ),
            },
        }
        beams = [beam]

        # time-synchronous decoding
        for t in range(eouts.size(1)):
            new_beams = []  # A
            beams_v = beams[:]  # C <- B

            for v in range(NUM_EXPANDS):
                new_beams_v = []  # D

                # prediction network, batched over all current hypotheses
                ys = torch.zeros(
                    (len(beams_v), 1), dtype=torch.int64, device=eouts.device
                )
                for i, beam in enumerate(beams_v):
                    ys[i] = beam["hyp"][-1]
                dstates_prev = {
                    "hs": torch.cat([beam["dstate"]["hs"] for beam in beams_v], dim=1),
                    "cs": torch.cat([beam["dstate"]["cs"] for beam in beams_v], dim=1),
                }
                douts, dstates = self.recurrency(ys, dstates_prev)

                # joint network
                logits = self.joint(eouts[:, t : t + 1], douts)
                scores_asr = torch.log_softmax(logits.squeeze(2).squeeze(1), dim=-1)

                # blank expansion
                for i, beam in enumerate(beams_v):
                    blank_score = scores_asr[i, self.blank_id].item()
                    new_beams.append(beam.copy())
                    new_beams[-1]["score"] += blank_score
                    new_beams[-1]["score_asr"] += blank_score
                    # NOTE: do not update `dstate` on blank expansion

                for i, beam in enumerate(beams_v):
                    beams_v[i]["dstate"] = {
                        "hs": dstates["hs"][:, i : i + 1],
                        "cs": dstates["cs"][:, i : i + 1],
                    }

                # non-blank expansion
                if v < NUM_EXPANDS - 1:
                    for i, beam in enumerate(beams_v):
                        # NOTE: `scores_asr[i, 1:]` assumes blank_id == 0
                        scores_topk, v_topk = torch.topk(
                            scores_asr[i, 1:],
                            k=beam_width,
                            dim=-1,
                            largest=True,
                            sorted=True,
                        )
                        v_topk += 1  # shift indices to account for the skipped blank

                        for k in range(beam_width):
                            v_index = v_topk[k].item()
                            new_beams_v.append(
                                {
                                    "hyp": beam["hyp"] + [v_index],
                                    "score": beam["score"] + scores_topk[k].item(),
                                    "score_asr": beam["score_asr"]
                                    + scores_topk[k].item(),
                                    "dout": None,
                                    "dstate": beam["dstate"],
                                }
                            )

                # Local pruning at each expansion
                new_beams_v = sorted(new_beams_v, key=lambda x: x["score"], reverse=True)
                new_beams_v = self._merge_rnnt_paths(new_beams_v)
                beams_v = new_beams_v[:beam_width]  # C <- D

            # Local pruning at the t-th index
            new_beams = sorted(new_beams, key=lambda x: x["score"], reverse=True)
            new_beams = self._merge_rnnt_paths(new_beams)
            beams = new_beams[:beam_width]  # B <- A

        hyps = [beam["hyp"] for beam in beams]
        return hyps

    def decode(
        self,
        eouts,
        elens,
        eouts_inter=None,
        beam_width=1,
        len_weight=0,
        lm=None,
        lm_weight=0,
        decode_ctc_weight=0,
        decode_phone=False,
    ):
        if beam_width <= 1:
            hyps, scores, logits, aligns = self._greedy(eouts, elens, decode_ctc_weight)
        else:
            hyps = self._beam_search(
                eouts, elens, beam_width, len_weight, lm, lm_weight
            )
            scores, logits, aligns = None, None, None

        return hyps, scores, logits, aligns

    @staticmethod
    def _merge_rnnt_paths(beams):
        merged_beams = {}

        for beam in beams:
            hyp = ints2str(beam["hyp"])
            if hyp in merged_beams.keys():
                # same label sequence: merge the path probabilities
                merged_beams[hyp]["score"] = np.logaddexp(
                    merged_beams[hyp]["score"], beam["score"]
                )
            else:
                merged_beams[hyp] = beam

        return list(merged_beams.values())
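
# ----------------------------------------------------------------------
# Shape sketch (illustration only, not part of the original module): the
# broadcast inside `RNNTDecoder.joint` pairs every encoder frame with
# every prediction-network state, producing the (T, U + 1) output lattice
# that `warp_rnnt.rnnt_loss` scores. All sizes below are assumed toy
# values.
def _example_joint_shapes():
    B, T, U, H, V = 2, 5, 3, 8, 11  # batch, frames, label length, hidden, vocab
    w_enc, w_dec = nn.Linear(H, H), nn.Linear(H, H)
    output = nn.Linear(H, V)
    eouts = torch.randn(B, T, H).unsqueeze(2)      # (B, T, 1, H)
    douts = torch.randn(B, U + 1, H).unsqueeze(1)  # (B, 1, U + 1, H)
    logits = output(torch.tanh(w_enc(eouts) + w_dec(douts)))
    assert logits.shape == (B, T, U + 1, V)  # one distribution per (frame, prefix)
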