# Standard-library and third-party imports used below. Project-local helpers
# (tensor2np, np2tensor, str2ints, strip_eos, make_nopad_mask, print_topk_probs,
# CTCPrefixScorer) and constants (LOG_0, LOG_STEP, BATCH_SIZE, SAVE_STEP,
# CTC_BEAM_WIDTH_RATIO), as well as the module-level `args` namespace, are
# provided elsewhere in the repository.
import logging
import math
import pickle

import numpy as np
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence


def aggregate_logits(logits, aligns, blank_id, reduction="max"):
    """Collapse frame-level CTC logits into token-level probabilities.

    `aligns` is a frame-level alignment (one token id per frame). Consecutive
    non-blank frames of the same token form a group; for each group the most
    confident frame is kept (reduction="max").
    """
    assert logits.size(0) == len(aligns)
    xlen = logits.size(0)

    token_probs_allv = []  # (num_tokens, vocab) softmax rows
    token_probs_allv_tmp = []  # softmax rows of the current token group
    token_probs_v = []  # confidence of each emitted token
    token_id_prev = None

    for t in range(xlen):
        token_id = aligns[t]
        if token_id == blank_id:
            continue
        # a new token starts here: flush the previous group
        if token_id != aligns[t - 1] and token_id_prev is not None:
            token_probs_allv_tmp = np.array(token_probs_allv_tmp)
            if reduction == "max":
                index = np.argmax(token_probs_allv_tmp[:, token_id_prev])
                token_probs_allv.append(token_probs_allv_tmp[index])
                token_probs_v.append(token_probs_allv_tmp[index, token_id_prev])
            token_probs_allv_tmp = []
        token_probs_allv_tmp.append(tensor2np(torch.softmax(logits[t], dim=-1)))
        token_id_prev = token_id

    # flush the final group (assumes at least one non-blank frame)
    token_probs_allv_tmp = np.array(token_probs_allv_tmp)
    if reduction == "max":
        index = np.argmax(token_probs_allv_tmp[:, token_id_prev])
        token_probs_allv.append(token_probs_allv_tmp[index])
        token_probs_v.append(token_probs_allv_tmp[index, token_id_prev])

    return np.array(token_probs_allv), np.array(token_probs_v)
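# A toy usage sketch for `aggregate_logits`. The shapes and values below are
# made up for illustration: `aligns` plays the role of a frame-level CTC
# alignment (e.g. the framewise argmax path), and the function collapses each
# run of non-blank frames into a single probability row.
def _example_aggregate_logits():
    blank_id = 0
    logits = torch.randn(6, 5)  # (T=6 frames, V=5 tokens)
    aligns = [0, 3, 3, 0, 2, 2]  # blank, "3" x2, blank, "2" x2 -> tokens [3, 2]
    probs_allv, probs_v = aggregate_logits(logits, aligns, blank_id)
    assert probs_allv.shape == (2, 5)  # one softmax row per collapsed token
    assert probs_v.shape == (2,)  # per-token confidence (used for masking below)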
def ppl_masked_lm(dataloader, model, device, mask_id, max_seq_len, vocab):
    """Compute masked-LM pseudo-perplexity by masking one position at a time."""
    cnt = 0
    sum_logprob = 0

    for i, data in enumerate(dataloader):
        if (i + 1) % LOG_STEP == 0:
            logging.info(
                f"{(i+1):>4} / {len(dataloader):>4} PPL: {math.exp(sum_logprob/cnt):.3f}"
            )

        utt_id = data["utt_ids"][0]
        ys = data["ys_in"].to(device)  # not masked
        ylens = data["ylens"].to(device)
        assert ys.size(0) == 1
        # for P2WDataset
        ps = data["ps"].to(device) if "ps" in data else None
        plens = data["plens"].to(device) if "plens" in data else None

        if ys.size(1) > max_seq_len:
            logging.warning(f"input longer than {max_seq_len:d} tokens; skipped")
            continue

        if args.print_probs:
            print("********************")
            print(f"{utt_id}: {vocab.ids2text(tensor2np(ys[0]))}")

        for mask_pos in range(ys.size(1)):
            ys_masked = ys.clone()
            label = ys[0, mask_pos]
            ys_masked[0, mask_pos] = mask_id

            with torch.no_grad():
                logits = model(ys_masked, ylens, labels=None, ps=ps, plens=plens)

            if args.print_probs:
                print(vocab.ids2text(tensor2np(ys_masked[0])))
                # TODO: print phones
                probs = torch.softmax(logits, dim=-1)
                p_topk, v_topk = torch.topk(probs[0, mask_pos], k=5)
                print(
                    f"{vocab.i2t[label.item()]} || "
                    + " | ".join(
                        [
                            f"{vocab.i2t[v.item()]}: {p.item():.2f}"
                            for p, v in zip(p_topk, v_topk)
                        ]
                    )
                )

            logprobs = torch.log_softmax(logits, dim=-1)
            sum_logprob -= logprobs[0, mask_pos, label].item()
            cnt += 1

    ppl = math.exp(sum_logprob / cnt)
    return cnt, ppl
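# What the loop above measures: every position is masked in turn and scored,
# which yields the masked-LM "pseudo-perplexity"
#     PPPL = exp( -(1/N) * sum_i log p(y_i | y_{\i}) ).
# A pure-arithmetic sketch of the final reduction, independent of any model:
def _pseudo_ppl(logprobs):
    """logprobs: log p(y_i | rest) for each of the N masked positions."""
    return math.exp(-sum(logprobs) / len(logprobs))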
def predict(self, ys, ylens, states=None):
    """Predict the next-token distribution (used for shallow fusion)."""
    attention_mask = make_nopad_mask(ylens).float().to(ys.device)

    with torch.no_grad():
        (logits,) = self.transformer(ys, attention_mask, causal=True)
        log_probs = torch.log_softmax(logits, dim=-1)

    # take the distribution at the last non-padded position of each sequence
    log_probs_next = []
    bs = len(ys)
    for b in range(bs):
        log_probs_next.append(tensor2np(log_probs[b, ylens[b] - 1]))

    return torch.tensor(log_probs_next).to(ys.device), states
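# How a decoder would typically consume `predict` during shallow fusion: the
# LM distribution is interpolated into the ASR scores at every step. A minimal
# sketch under assumed names (`asr_log_probs` and `lm_weight` are illustrative,
# not identifiers from this repository):
def _fuse_shallow(asr_log_probs, lm, ys, ylens, lm_weight=0.3):
    lm_log_probs, _ = lm.predict(ys, ylens, states=None)
    return asr_log_probs + lm_weight * lm_log_probs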
def save_feats(wav_path):
    """Extract 80-dim log-mel filterbank features and save them as .npy."""
    with torch.no_grad():
        wav, sr = torchaudio.load(wav_path)
        assert sr == 16000
        wav *= 2 ** 15  # scale to 16-bit integer range, as Kaldi expects

        lmfb = torchaudio.compliance.kaldi.fbank(
            wav,
            window_type="hamming",
            htk_compat=True,
            sample_frequency=16000,
            num_mel_bins=80,
            use_energy=False,
        )

    lmfb = tensor2np(lmfb)
    npy_path = wav_path.replace(".wav", ".npy")
    np.save(npy_path, lmfb)

    # per-utterance statistics for global CMVN
    lmfb_sum = np.sum(lmfb, axis=0)
    lmfb_sqsum = np.sum(lmfb * lmfb, axis=0)
    num_frames = lmfb.shape[0]
    return lmfb_sum, lmfb_sqsum, num_frames
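# Why `save_feats` returns (sum, squared sum, frame count): per-dimension
# global CMVN statistics can be accumulated across utterances and finalized as
# below. A minimal sketch; the function name is ours, not the project's.
def _finalize_cmvn(sums, sqsums, counts):
    """sums/sqsums: per-utterance (80,) arrays; counts: per-utterance frame counts."""
    n = float(sum(counts))
    mean = np.sum(sums, axis=0) / n  # E[x]
    var = np.sum(sqsums, axis=0) / n - mean ** 2  # E[x^2] - E[x]^2
    return mean, np.sqrt(np.maximum(var, 1e-10))  # floor for numerical safety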
def make_lm_label(
    df,
    model,
    device,
    save_path,
    topk=8,
    temp=3.0,
    add_sos_eos=False,
    eos_id=2,
    max_seq_len=256,
):
    """Generate temperature-smoothed top-k soft labels with an LM and pickle them."""
    labels = {}
    utt_ids, ys, ylens, start_poss, end_poss = [], [], [], [], []  # batch buffers

    for i, row in enumerate(df.itertuples()):
        ids = str2ints(row.token_id)
        if add_sos_eos:
            if len(ids) <= max_seq_len - 2:
                ids = [eos_id] + ids + [eos_id]
                start_pos = row.start_pos + 1
                end_pos = row.end_pos + 1
            else:
                # reduce context
                ids = [eos_id] + ids[1:-1] + [eos_id]
                start_pos = row.start_pos
                end_pos = row.end_pos
        else:
            start_pos = row.start_pos
            end_pos = row.end_pos

        y = torch.tensor(ids)
        ylen = len(ids)

        utt_ids.append(row.utt_id)
        ys.append(y)
        ylens.append(ylen)
        start_poss.append(start_pos)
        end_poss.append(end_pos)

        # batchify
        if (i + 1) % BATCH_SIZE == 0 or (i + 1) == len(df):
            bs = len(ys)
            ys_pad = pad_sequence(ys, batch_first=True).to(device)
            ylens = torch.tensor(ylens).to(device)

            with torch.no_grad():
                logits = model(ys_pad, ylens)

            for b in range(bs):
                utt_id = utt_ids[b]
                start_pos = start_poss[b]
                end_pos = end_poss[b]
                y = ys[b]

                for pos in range(start_pos, end_pos):
                    if pos == 0:
                        # no left context: fall back to the reference token
                        v_topk = np.array([y[pos]])
                        p_topk = np.array([1.0])
                        logging.warning(f"hard label is used: {v_topk}")
                    else:
                        o_sorted, v_sorted = torch.sort(logits[b, pos - 1], descending=True)
                        o_topk = o_sorted[:topk]
                        v_topk = tensor2np(v_sorted[:topk])
                        p_topk = tensor2np(torch.softmax(o_topk / temp, dim=0))

                    label = []
                    for v, p in zip(v_topk, p_topk):
                        # NOTE: do not add <eos> to soft labels
                        if add_sos_eos and v == eos_id:
                            continue
                        label.append((v, p))

                    if utt_id not in labels:  # first token in utterance
                        labels[utt_id] = [label]
                    else:
                        labels[utt_id].append(label)

            utt_ids, ys, ylens, start_poss, end_poss = [], [], [], [], []

        if (i + 1) % LOG_STEP == 0:
            logging.info(f"{(i+1):>4} / {len(df):>4}")
        if (i + 1) == SAVE_STEP:
            save_tmp_path = save_path + ".tmp"
            with open(save_tmp_path, "wb") as f:
                pickle.dump(labels, f)
            logging.info(f"pickle is saved to {save_tmp_path}")

    with open(save_path, "wb") as f:
        pickle.dump(labels, f)
    logging.info(f"pickle is saved to {save_path}")
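# How the pickled labels might be consumed for knowledge distillation: each
# `labels[utt_id][pos]` is a sparse top-k list of (token_id, prob) pairs. A
# minimal sketch of the resulting soft cross-entropy, under assumed shapes
# (`student_log_probs` is a (T, V) tensor of the student's log-probabilities):
def _soft_label_loss(student_log_probs, soft_labels):
    loss = 0.0
    for pos, label in enumerate(soft_labels):
        for v, p in label:  # sparse teacher distribution at this position
            loss -= p * student_log_probs[pos, v]
    return loss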
def _beam_search(
    self, eouts, elens, beam_width=1, len_weight=0, lm=None, lm_weight=0
):
    """CTC prefix beam search decoding

    Reference:
        https://towardsdatascience.com/beam-search-decoding-in-ctc-trained-neural-networks-5a889a3d85a7
    """
    bs = eouts.size(0)
    assert bs == 1

    logits = self.output(eouts)
    log_probs = torch.log_softmax(logits, dim=-1)

    # init
    beam = {
        "hyp": [self.eos_id],  # <eos> is used for LM
        "score": 0.0,
        "p_b": 0.0,  # log-probability of the prefix ending with blank
        "p_nb": LOG_0,  # log-probability of the prefix ending with non-blank
        "score_asr": 0.0,
        "score_lm": 0.0,
        "score_len": 0.0,
    }
    beams = [beam]

    for t in range(eouts.size(1)):
        new_beams = []

        _, v_topk = torch.topk(
            log_probs[:, t],
            k=min(beam_width, self.vocab_size),
            dim=-1,
            largest=True,
            sorted=True,
        )

        if lm_weight > 0:
            # batchify all hypotheses for a single LM forward pass
            hyps_batch = pad_sequence(
                [torch.tensor(beam["hyp"], device=eouts.device) for beam in beams],
                batch_first=True,
            )
            hyp_lens_batch = torch.tensor(
                [len(beam["hyp"]) for beam in beams], device=eouts.device
            )
            lm_log_prob_batch, _ = lm.predict(
                hyps_batch, hyp_lens_batch, states=None
            )

        for b, beam in enumerate(beams):
            hyp = beam["hyp"]
            p_b = beam["p_b"]  # end with blank
            p_nb = beam["p_nb"]  # end with non-blank
            score_asr = beam["score_asr"]
            score_lm = beam["score_lm"]
            score_len = beam["score_len"]

            # case 1. hyp is not extended (copy the last)
            new_p_b = np.logaddexp(
                p_b + log_probs[0, t, self.blank_id].item(),
                p_nb + log_probs[0, t, self.blank_id].item(),
            )
            if len(hyp) > 1:
                new_p_nb = p_nb + log_probs[0, t, hyp[-1]].item()
            else:
                new_p_nb = LOG_0
            score_asr = np.logaddexp(new_p_b, new_p_nb)

            new_beams.append(
                {
                    "hyp": hyp,
                    "score": score_asr + score_lm + score_len,
                    "p_b": new_p_b,
                    "p_nb": new_p_nb,
                    "score_asr": score_asr,
                    "score_lm": score_lm,
                    "score_len": score_len,
                }
            )

            # case 2. hyp is extended
            new_p_b = LOG_0
            for v in tensor2np(v_topk[0]):
                p_t = log_probs[0, t, v].item()
                if v == self.blank_id:
                    continue

                v_prev = hyp[-1] if len(hyp) > 1 else None
                if v == v_prev:  # repeat: only a blank-ended prefix can extend
                    new_p_nb = p_b + p_t
                else:
                    new_p_nb = np.logaddexp(p_b + p_t, p_nb + p_t)
                score_asr = np.logaddexp(new_p_b, new_p_nb)
                score_len = len_weight * (len(strip_eos(hyp, self.eos_id)) + 1)
                if lm_weight > 0:
                    # LM score of this candidate relative to the beam's base,
                    # so that candidates do not accumulate each other's scores
                    score_lm = beam["score_lm"] + lm_weight * lm_log_prob_batch[b, v].item()

                new_beams.append(
                    {
                        "hyp": hyp + [v],
                        "score": score_asr + score_lm + score_len,
                        "p_b": new_p_b,
                        "p_nb": new_p_nb,
                        "score_asr": score_asr,
                        "score_lm": score_lm,
                        "score_len": score_len,
                    }
                )

        # merge the same hyp
        new_beams = self._merge_ctc_paths(new_beams)
        beams = sorted(new_beams, key=lambda x: x["score"], reverse=True)[:beam_width]

    hyps = [beam["hyp"] for beam in beams]
    scores = [beam["score"] for beam in beams]
    return hyps, scores, logits
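# `_merge_ctc_paths` is called above but defined elsewhere. In prefix beam
# search, different extension orders can produce the identical label sequence,
# and their blank/non-blank probabilities must be summed in log space. A
# minimal sketch of that merge, assuming the beam dicts used above:
def _merge_ctc_paths_sketch(beams):
    merged = {}
    for beam in beams:
        key = tuple(beam["hyp"])
        if key not in merged:
            merged[key] = beam
            continue
        kept = merged[key]
        kept["p_b"] = np.logaddexp(kept["p_b"], beam["p_b"])
        kept["p_nb"] = np.logaddexp(kept["p_nb"], beam["p_nb"])
        kept["score_asr"] = np.logaddexp(kept["p_b"], kept["p_nb"])
        kept["score"] = kept["score_asr"] + kept["score_lm"] + kept["score_len"]
    return list(merged.values())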
def decode(
    self,
    eouts,
    elens,
    eouts_inter=None,
    beam_width=1,
    len_weight=0,
    lm=None,
    lm_weight=0,
    decode_ctc_weight=0,
    decode_phone=False,
):
    """Beam search decoding with optional CTC and LM score fusion"""
    bs = eouts.size(0)

    if decode_ctc_weight == 1:
        print("CTC is used")
        # greedy CTC decoding
        return self.ctc.decode(eouts, elens, beam_width=1)

    assert bs == 1

    # init
    beam = {
        "hyp": [self.eos_id],
        "score": 0.0,
        "score_ctc": 0.0,
        "ctc_state": None,
        "score_lm": 0.0,
        "lm_state": None,
    }
    if decode_ctc_weight > 0:
        ctc_logits = self.ctc(eouts, elens)
        ctc_log_probs = torch.log_softmax(ctc_logits, dim=-1)
        ctc_scorer = CTCPrefixScorer(
            tensor2np(ctc_log_probs.squeeze(0)),
            blank_id=self.blank_id,
            eos_id=self.eos_id,
        )
        beam["score_ctc"] = 0.0
        beam["ctc_state"] = ctc_scorer.initial_state()
        ctc_beam_width = min(
            ctc_log_probs.size(2), int(beam_width * CTC_BEAM_WIDTH_RATIO)
        )
    beams = [beam]
    results = []

    for i in range(self.max_decode_ylen):
        new_beams = []

        for beam in beams:
            ys_in = torch.tensor([beam["hyp"]]).to(eouts.device)
            ylens_in = torch.tensor([i + 1]).to(eouts.device)

            scores_att = torch.log_softmax(
                self.forward_one_step(ys_in, ylens_in, eouts), dim=-1
            )  # (1, vocab)
            # clone so that the in-place additions below keep `scores_att`
            # intact for the CTC rescoring branch
            scores = scores_att.clone()

            if lm_weight > 0:
                scores_lm, _ = lm.predict(ys_in, ylens_in, states=None)  # (1, vocab)
                scores += lm_weight * scores_lm[:, : self.vocab_size]

            if decode_ctc_weight > 0:
                score_ctc_prev = beam["score_ctc"]
                ctc_state_prev = beam["ctc_state"]
                # shortlist candidates first, then rescore them with CTC
                scores_topb, v_topb = torch.topk(scores, k=ctc_beam_width, dim=1)
                scores_ctc, ctc_state = ctc_scorer(
                    beam["hyp"], v_topb[0], ctc_state_prev
                )
                # re-calculate score
                scores = (1 - decode_ctc_weight) * scores_att[:, v_topb[0]] \
                    + decode_ctc_weight * np2tensor(scores_ctc - score_ctc_prev)
                if lm_weight > 0:
                    scores += lm_weight * scores_lm[:, v_topb[0]]
                scores_topk, ids_topk = torch.topk(scores, k=beam_width, dim=1)
                v_topk = v_topb[:, ids_topk[0]]
            else:
                scores_topk, v_topk = torch.topk(scores, k=beam_width, dim=1)

            for j in range(beam_width):
                new_beam = {}
                new_beam["score"] = beam["score"] + float(scores_topk[0, j])
                new_beam["hyp"] = beam["hyp"] + [int(v_topk[0, j])]
                if decode_ctc_weight > 0:
                    new_beam["score_ctc"] = scores_ctc[ids_topk[0, j]]
                    new_beam["ctc_state"] = ctc_state[ids_topk[0, j]]
                new_beams.append(new_beam)

        # update `beams`
        beams = sorted(new_beams, key=lambda x: x["score"], reverse=True)[:beam_width]

        beams_extend = []
        for beam in beams:
            # ended beams
            if beam["hyp"][-1] == self.eos_id:
                hyp_noeos = strip_eos(beam["hyp"], self.eos_id)
                # only <eos> is not acceptable
                if len(hyp_noeos) < 1:
                    continue
                # add length penalty
                score = beam["score"] + len_weight * len(beam["hyp"])
                results.append({"hyp": hyp_noeos, "score": score})
                if len(results) >= beam_width:
                    break
            else:
                beams_extend.append(beam)
        if len(results) >= beam_width:
            break
        beams = beams_extend

    results = sorted(results, key=lambda x: x["score"], reverse=True)
    hyps = [result["hyp"] for result in results]
    scores = [result["score"] for result in results]
    logits = None
    aligns = None
    return hyps, scores, logits, aligns
def test_step(
    model,
    lm,
    data,
    blank_id,
    mask_id,
    mask_th,
    device,
    vocab,
    vocab_size,
    vocab_phone=None,
    debug=False,
):
    utt_id = data["utt_ids"][0]
    xs = data["xs"].to(device)
    xlens = data["xlens"].to(device)
    reftext = data["texts"][0]

    # ASR (word)
    hyps, scores, logits, aligns = model.decode(xs, xlens, beam_width=0, len_weight=0)
    hyp = np.array(hyps[0])
    if len(hyp) < 1:
        return utt_id, [], [], reftext, 0, 0

    # ASR (phone)
    if vocab_phone is not None:
        hyps_phone, _, _, _ = model.decode(
            xs, xlens, beam_width=0, len_weight=0, decode_phone=True
        )
        hyp_phone = np.array(hyps_phone[0])
        if len(hyp_phone) < 1:
            return utt_id, [], [], reftext, 0, 0

    hyp_masked = hyp.copy()
    token_probs, token_probs_v = aggregate_logits(logits[0], aligns[0], blank_id)
    assert len(hyp) == len(token_probs)
    assert len(hyp) == len(token_probs_v)

    # mask less confident tokens
    mask_indices = token_probs_v < mask_th
    hyp_masked[mask_indices] = mask_id
    num_masked = sum(mask_indices)
    num_tokens = len(mask_indices)

    y = np2tensor(hyp_masked)
    if vocab_phone is None:
        logits = lm(y.unsqueeze(0).to(device))
    else:
        p = np2tensor(hyp_phone)
        logits = lm(y.unsqueeze(0).to(device), ps=p.unsqueeze(0).to(device))
    lm_token_probs = tensor2np(torch.softmax(logits[0], dim=-1))

    # fusion
    token_probs_mix = (
        (1 - args.lm_weight) * token_probs[:, :vocab_size]
        + args.lm_weight * lm_token_probs[:, :vocab_size]
    )
    y_gen = np.argmax(token_probs_mix, axis=-1)

    hyp_cor = hyp.copy()
    hyp_cor[mask_indices] = y_gen[mask_indices]

    if debug:
        print(f"*** {utt_id} ***")
        print(f"Ref.: {reftext}")
        print(f"Hyp.(word): {' '.join(vocab.ids2tokens(hyp))}")
        if vocab_phone is not None:
            print(f"Hyp.(phone): {' '.join(vocab_phone.ids2tokens(hyp_phone))}")
        print(
            f"Hyp.(masked): {' '.join(vocab.ids2tokens(hyp_masked))} "
            f"({num_masked:d}/{num_tokens:d} masked)"
        )
        print("ASR probs:")
        token_probs_masked = token_probs[mask_indices]
        print_topk_probs(token_probs_masked, vocab=vocab)
        print("LM probs:")
        lm_token_probs_masked = lm_token_probs[mask_indices]
        print_topk_probs(lm_token_probs_masked, vocab=vocab)
        print(f"Hyp.(correct): {' '.join(vocab.ids2tokens(hyp_cor))}")

    return utt_id, hyp, hyp_cor, reftext, num_masked, num_tokens
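# A minimal driver sketch for `test_step` over a dataset; the loop mirrors the
# one in `ppl_masked_lm` and only aggregates masking statistics (error-rate
# scoring of `hyp` vs. `hyp_cor` would go where the hypotheses are collected):
def _run_correction(dataloader, model, lm, blank_id, mask_id, mask_th, device,
                    vocab, vocab_size):
    num_masked_total, num_tokens_total = 0, 0
    for data in dataloader:
        utt_id, hyp, hyp_cor, reftext, num_masked, num_tokens = test_step(
            model, lm, data, blank_id, mask_id, mask_th, device, vocab, vocab_size
        )
        num_masked_total += num_masked
        num_tokens_total += num_tokens
    logging.info(f"masked {num_masked_total}/{num_tokens_total} tokens in total")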