Example #1
import numpy as np
import torch

# `tensor2np` (tensor -> numpy on CPU) is a project helper assumed to be
# defined elsewhere.


def aggregate_logits(logits, aligns, blank_id, reduction="max"):
    """Aggregate frame-level logits into one probability vector per emitted
    token, following a CTC alignment; with reduction="max", each token is
    represented by its most confident frame."""
    assert logits.size(0) == len(aligns)
    xlen = logits.size(0)

    token_probs_allv = []      # per token: full softmax vector (reduced)
    token_probs_allv_tmp = []  # softmax vectors of frames in the current group
    token_probs_v = []         # per token: probability of the token itself

    token_id_prev = None

    for t in range(xlen):
        token_id = aligns[t]

        if token_id == blank_id:
            continue

        # flush the previous group when the raw alignment changes
        # (aligns[t - 1] may be blank, which also separates repeated tokens)
        if token_id_prev is not None and token_id != aligns[t - 1]:
            token_probs_allv_tmp = np.array(token_probs_allv_tmp)
            if reduction == "max":
                index = np.argmax(token_probs_allv_tmp[:, token_id_prev])
                token_probs_allv.append(token_probs_allv_tmp[index])
                token_probs_v.append(token_probs_allv_tmp[index, token_id_prev])
            token_probs_allv_tmp = []

        token_probs_allv_tmp.append(tensor2np(torch.softmax(logits[t], dim=-1)))
        token_id_prev = token_id

    # flush the final group (skip if the alignment contained only blanks)
    if token_id_prev is not None:
        token_probs_allv_tmp = np.array(token_probs_allv_tmp)
        if reduction == "max":
            index = np.argmax(token_probs_allv_tmp[:, token_id_prev])
            token_probs_allv.append(token_probs_allv_tmp[index])
            token_probs_v.append(token_probs_allv_tmp[index, token_id_prev])

    return np.array(token_probs_allv), np.array(token_probs_v)
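
# Minimal usage sketch (the shapes and alignment below are made-up
# illustrations, not values from the original code):
logits = torch.randn(6, 4)    # T=6 frames, V=4 classes (id 0 = blank)
aligns = [0, 2, 2, 0, 3, 3]   # frame-level CTC alignment
probs_all, probs_v = aggregate_logits(logits, aligns, blank_id=0)
# probs_all: (2, 4) -- one softmax vector per emitted token (ids 2 and 3)
# probs_v:   (2,)   -- probability assigned to each emitted token itself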
Example #2
import logging
import math

import torch

# `LOG_STEP`, `tensor2np`, and the parsed command-line `args` (for
# `args.print_probs`) are assumed to be defined at module level.


def ppl_masked_lm(dataloader, model, device, mask_id, max_seq_len, vocab):
    """Pseudo-perplexity of a masked LM: mask one position at a time and
    score the gold token (one forward pass per position)."""
    cnt = 0
    sum_logprob = 0  # accumulates negative log-probabilities

    for i, data in enumerate(dataloader):
        if cnt > 0 and (i + 1) % LOG_STEP == 0:
            logging.info(
                f"{(i+1):>4} / {len(dataloader):>4} PPL: {math.exp(sum_logprob/cnt):.3f}"
            )
        utt_id = data["utt_ids"][0]
        ys = data["ys_in"].to(device)  # not masked
        ylens = data["ylens"].to(device)
        assert ys.size(0) == 1

        # for P2WDataset
        ps = data["ps"].to(device) if "ps" in data else None
        plens = data["plens"].to(device) if "plens" in data else None

        if ys.size(1) > max_seq_len:
            logging.warning(f"input longer than {max_seq_len:d} tokens; skipping")
            continue

        if args.print_probs:
            print("********************")
            print(f"{utt_id}: {vocab.ids2text(tensor2np(ys[0]))}")

        for mask_pos in range(ys.size(1)):
            ys_masked = ys.clone()
            label = ys[0, mask_pos]
            ys_masked[0, mask_pos] = mask_id

            with torch.no_grad():
                logits = model(ys_masked,
                               ylens,
                               labels=None,
                               ps=ps,
                               plens=plens)

            if args.print_probs:
                print(vocab.ids2text(tensor2np(ys_masked[0])))
                # TODO: print phones

                probs = torch.softmax(logits, dim=-1)
                p_topk, v_topk = torch.topk(probs[0, mask_pos], k=5)
                print(f"{vocab.i2t[label.item()]} || " + " | ".join([
                    f"{vocab.i2t[v.item()]}: {p.item():.2f}"
                    for p, v in zip(p_topk, v_topk)
                ]))

            logprobs = torch.log_softmax(logits, dim=-1)
            sum_logprob -= logprobs[0, mask_pos, label].item()
            cnt += 1

    ppl = math.exp(sum_logprob / cnt)

    return cnt, ppl
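
# Sanity note on the quantity computed above: pseudo-perplexity is the exp
# of the mean negative log-probability over all masked predictions. A tiny
# self-contained check with made-up per-token NLLs:
nlls = [2.1, 0.7, 1.3]
ppl = math.exp(sum(nlls) / len(nlls))  # == exp(mean NLL)
print(f"PPL = {ppl:.3f}")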
Example #3
    def predict(self, ys, ylens, states=None):
        """Predict next-token log-probabilities for shallow fusion.

        Returns the distribution at the last valid (non-padded) position of
        each sequence, shape (batch, vocab).
        """
        attention_mask = make_nopad_mask(ylens).float().to(ys.device)

        with torch.no_grad():
            (logits,) = self.transformer(ys, attention_mask, causal=True)

        log_probs = torch.log_softmax(logits, dim=-1)

        # gather the distribution at the last valid position of each sequence
        bs = ys.size(0)
        idx = torch.arange(bs, device=ys.device)
        log_probs_next = log_probs[idx, ylens - 1]  # (batch, vocab)

        return log_probs_next, states
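
    # Sketch of how `predict` is consumed for shallow fusion during decoding
    # (`asr_log_probs`, `ys_in`, `ylens_in`, and the 0.3 weight are
    # illustrative assumptions, not from the original code):
    #
    #     lm_log_probs, _ = lm.predict(ys_in, ylens_in)      # (batch, vocab)
    #     scores = asr_log_probs + 0.3 * lm_log_probs        # fuse in log space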
Example #4
import numpy as np
import torch
import torchaudio

# `tensor2np` (tensor -> numpy on CPU) is a project helper assumed to be
# defined elsewhere.


def save_feats(wav_path):
    """Extract Kaldi-compatible 80-dim log-Mel filterbank features, save them
    next to the wav as .npy, and return per-file statistics for CMVN."""
    with torch.no_grad():
        wav, sr = torchaudio.load(wav_path)
        assert sr == 16000
        wav *= 2 ** 15  # torchaudio loads floats in [-1, 1]; Kaldi expects int16 range
        lmfb = torchaudio.compliance.kaldi.fbank(
            wav,
            window_type="hamming",
            htk_compat=True,
            sample_frequency=16000,
            num_mel_bins=80,
            use_energy=False,
        )
        lmfb = tensor2np(lmfb)

        npy_path = wav_path.replace(".wav", ".npy")
        np.save(npy_path, lmfb)

        # accumulate sufficient statistics for global mean/variance (CMVN)
        lmfb_sum = np.sum(lmfb, axis=0)
        lmfb_sqsum = np.sum(lmfb * lmfb, axis=0)
        num_frames = lmfb.shape[0]

    return lmfb_sum, lmfb_sqsum, num_frames
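
# The per-file sums are meant to be pooled over a corpus into a global
# mean/std for feature normalization (CMVN); a minimal sketch, where
# `wav_paths` is an assumed list of wav files:
total_sum, total_sqsum, total_frames = 0.0, 0.0, 0
for path in wav_paths:
    s, sq, n = save_feats(path)
    total_sum += s
    total_sqsum += sq
    total_frames += n
mean = total_sum / total_frames
std = np.sqrt(total_sqsum / total_frames - mean ** 2)  # Var = E[x^2] - E[x]^2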
Example #5
import logging
import pickle

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence

# `str2ints`, `tensor2np`, `BATCH_SIZE`, `LOG_STEP`, and `SAVE_STEP` are
# assumed to be defined at module level.


def make_lm_label(
    df,
    model,
    device,
    save_path,
    topk=8,
    temp=3.0,
    add_sos_eos=False,
    eos_id=2,
    max_seq_len=256,
):
    """Generate temperature-smoothed top-k soft labels from a causal LM for
    the token positions [start_pos, end_pos) of each row in `df`, and pickle
    them to `save_path`."""
    labels = {}

    utt_ids, ys, ylens, start_poss, end_poss = [], [], [], [], []  # batch

    for i, row in enumerate(df.itertuples()):
        ids = str2ints(row.token_id)

        if add_sos_eos:
            # NOTE: <eos> doubles as <sos> at the start of the sequence
            if len(ids) <= max_seq_len - 2:
                ids = [eos_id] + ids + [eos_id]
                start_pos = row.start_pos + 1
                end_pos = row.end_pos + 1
            else:
                # too long: drop the first and last context tokens to make
                # room, so the positions of interest stay unchanged
                ids = [eos_id] + ids[1:-1] + [eos_id]
                start_pos = row.start_pos
                end_pos = row.end_pos
        else:
            start_pos = row.start_pos
            end_pos = row.end_pos

        y = torch.tensor(ids)
        ylen = len(ids)

        utt_ids.append(row.utt_id)
        ys.append(y)
        ylens.append(ylen)
        start_poss.append(start_pos)
        end_poss.append(end_pos)

        # batchify
        if (i + 1) % BATCH_SIZE == 0 or (i + 1) == len(df):
            bs = len(ys)
            ys_pad = pad_sequence(ys, batch_first=True).to(device)
            ylens = torch.tensor(ylens).to(device)

            with torch.no_grad():
                logits = model(ys_pad, ylens)

            for b in range(bs):
                utt_id = utt_ids[b]
                start_pos = start_poss[b]
                end_pos = end_poss[b]
                y = ys[b]

                for pos in range(start_pos, end_pos):
                    if pos == 0:
                        # no left context to condition on: fall back to the
                        # hard (one-hot) label
                        v_topk = np.array([y[pos]])
                        p_topk = np.array([1.0])
                        logging.warning(f"hard label is used: {v_topk}")
                    else:
                        # causal LM: logits at pos-1 predict the token at pos
                        o_sorted, v_sorted = torch.sort(logits[b, pos - 1],
                                                        descending=True)
                        o_topk = o_sorted[:topk]
                        v_topk = tensor2np(v_sorted[:topk])
                        p_topk = tensor2np(
                            torch.softmax((o_topk / temp), dim=0))

                    label = []
                    for v, p in zip(v_topk, p_topk):
                        # NOTE: do not add <eos> to soft labels
                        if add_sos_eos and v == eos_id:
                            continue
                        label.append((v, p))

                    if utt_id not in labels:  # first token in utterance
                        labels[utt_id] = [label]
                    else:
                        labels[utt_id].append(label)

            utt_ids, ys, ylens, start_poss, end_poss = [], [], [], [], []

        if (i + 1) % LOG_STEP == 0:
            logging.info(f"{(i+1):>4} / {len(df):>4}")
        if (i + 1) % SAVE_STEP == 0:  # periodic snapshot
            save_tmp_path = save_path + ".tmp"
            with open(save_tmp_path, "wb") as f:
                pickle.dump(labels, f)
            logging.info(f"pickle is saved to {save_tmp_path}")

    with open(save_path, "wb") as f:
        pickle.dump(labels, f)
    logging.info(f"pickle is saved to {save_path}")
Example #6
    # assumes module-level: numpy as np, torch, pad_sequence (from
    # torch.nn.utils.rnn), tensor2np, strip_eos, and LOG_0 (a large negative
    # constant standing in for log 0)
    def _beam_search(
        self, eouts, elens, beam_width=1, len_weight=0, lm=None, lm_weight=0
    ):
        """ Beam search decoding

        Reference:
            https://towardsdatascience.com/beam-search-decoding-in-ctc-trained-neural-networks-5a889a3d85a7
        """
        bs = eouts.size(0)
        assert bs == 1

        logits = self.output(eouts)
        log_probs = torch.log_softmax(logits, dim=-1)

        # init: all scores are kept in log space
        beam = {
            "hyp": [self.eos_id],  # <eos> doubles as <sos> for the LM
            "score": 0.0,
            "p_b": 0.0,      # log P(prefix ends in blank) = log(1)
            "p_nb": LOG_0,   # log P(prefix ends in non-blank) = log(0)
            "score_asr": 0.0,
            "score_lm": 0.0,
            "score_len": 0.0,
        }
        beams = [beam]

        for t in range(eouts.size(1)):
            new_beams = []

            _, v_topk = torch.topk(
                log_probs[:, t],
                k=min(beam_width, self.vocab_size),
                dim=-1,
                largest=True,
                sorted=True,
            )

            if lm_weight > 0:
                # batchify
                hyps_batch = pad_sequence(
                    [torch.tensor(beam["hyp"], device=eouts.device) for beam in beams],
                    batch_first=True,
                )
                hyp_lens_batch = torch.tensor(
                    [len(beam["hyp"]) for beam in beams], device=eouts.device
                )
                lm_log_prob_batch, _ = lm.predict(
                    hyps_batch, hyp_lens_batch, states=None
                )

            for b, beam in enumerate(beams):
                hyp = beam["hyp"]
                p_b = beam["p_b"]  # end with blank
                p_nb = beam["p_nb"]  # end with non-blank
                score_asr = beam["score_asr"]
                score_lm = beam["score_lm"]
                score_len = beam["score_len"]

                # case 1. hyp is not extended (copy the last)
                new_p_b = np.logaddexp(
                    p_b + log_probs[0, t, self.blank_id].item(),
                    p_nb + log_probs[0, t, self.blank_id].item(),
                )
                if len(hyp) > 1:
                    new_p_nb = p_nb + log_probs[0, t, hyp[-1]].item()
                else:
                    new_p_nb = LOG_0
                score_asr = np.logaddexp(new_p_b, new_p_nb)

                new_beams.append(
                    {
                        "hyp": hyp,
                        "score": score_asr + score_lm + score_len,
                        "p_b": new_p_b,
                        "p_nb": new_p_nb,
                        "score_asr": score_asr,
                        "score_lm": score_lm,
                        "score_len": score_len,
                    }
                )

                # case 2. hyp is extended
                new_p_b = LOG_0
                for v in tensor2np(v_topk[0]):
                    p_t = log_probs[0, t, v].item()
                    if v == self.blank_id:
                        continue
                    v_prev = hyp[-1] if len(hyp) > 1 else None
                    if v == v_prev:
                        new_p_nb = p_b + p_t
                    else:
                        new_p_nb = np.logaddexp(p_b + p_t, p_nb + p_t)

                    score_asr = np.logaddexp(new_p_b, new_p_nb)
                    score_len = len_weight * (len(strip_eos(hyp, self.eos_id)) + 1)
                    # NOTE: use a fresh variable; `score_lm += ...` would leak
                    # this candidate's LM score into the next candidate v
                    new_score_lm = score_lm
                    if lm_weight > 0:
                        new_score_lm = score_lm + lm_weight * lm_log_prob_batch[b, v].item()

                    new_beams.append(
                        {
                            "hyp": hyp + [v],
                            "score": score_asr + new_score_lm + score_len,
                            "p_b": new_p_b,
                            "p_nb": new_p_nb,
                            "score_asr": score_asr,
                            "score_lm": new_score_lm,
                            "score_len": score_len,
                        }
                    )
                    )

            # merge beams that share the same hyp (distinct CTC paths)
            new_beams = self._merge_ctc_paths(new_beams)
            beams = sorted(new_beams, key=lambda x: x["score"], reverse=True)[
                :beam_width
            ]

        hyps = [beam["hyp"] for beam in beams]
        scores = [beam["score"] for beam in beams]

        return hyps, scores, logits
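
    # `_merge_ctc_paths` is referenced above but not shown in this excerpt;
    # a minimal sketch of what it is assumed to do: beams whose `hyp` is
    # identical are different CTC paths of the same prefix, so their path
    # probabilities are combined with logaddexp.
    def _merge_ctc_paths(self, beams):
        merged = {}
        for beam in beams:
            key = tuple(beam["hyp"])
            if key not in merged:
                merged[key] = dict(beam)
            else:
                m = merged[key]
                m["p_b"] = np.logaddexp(m["p_b"], beam["p_b"])
                m["p_nb"] = np.logaddexp(m["p_nb"], beam["p_nb"])
                m["score_asr"] = np.logaddexp(m["score_asr"], beam["score_asr"])
                m["score"] = m["score_asr"] + m["score_lm"] + m["score_len"]
        return list(merged.values())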
Example #7
    # assumes module-level: numpy as np, torch, tensor2np, np2tensor,
    # strip_eos, CTCPrefixScorer, and CTC_BEAM_WIDTH_RATIO
    def decode(
        self,
        eouts,
        elens,
        eouts_inter=None,
        beam_width=1,
        len_weight=0,
        lm=None,
        lm_weight=0,
        decode_ctc_weight=0,
        decode_phone=False,
    ):
        """ Beam search decoding
        """
        bs = eouts.size(0)
        if decode_ctc_weight == 1:
            # pure CTC: fall back to greedy CTC decoding
            print("CTC is used")
            return self.ctc.decode(eouts, elens, beam_width=1)

        assert bs == 1

        # init
        beam = {
            "hyp": [self.eos_id],
            "score": 0.0,
            "score_ctc": 0.0,
            "ctc_state": None,
            "score_lm": 0.0,
            "lm_state": None,
        }
        if decode_ctc_weight > 0:
            ctc_logits = self.ctc(eouts, elens)
            ctc_log_probs = torch.log_softmax(ctc_logits, dim=-1)

            ctc_scorer = CTCPrefixScorer(
                tensor2np(ctc_log_probs.squeeze(0)),
                blank_id=self.blank_id,
                eos_id=self.eos_id,
            )
            beam["score_ctc"] = 0.0
            beam["ctc_state"] = ctc_scorer.initial_state()
            ctc_beam_width = min(ctc_log_probs.size(2),
                                 int(beam_width * CTC_BEAM_WIDTH_RATIO))
        beams = [beam]

        results = []

        for i in range(self.max_decode_ylen):
            new_beams = []

            for beam in beams:
                ys_in = torch.tensor([beam["hyp"]]).to(eouts.device)
                ylens_in = torch.tensor([i + 1]).to(eouts.device)

                scores_att = torch.log_softmax(
                    self.forward_one_step(ys_in, ylens_in, eouts), dim=-1
                )  # (1, vocab)
                # clone so the in-place addition below does not also modify
                # `scores_att`, which is reused for CTC rescoring
                scores = scores_att.clone()

                if lm_weight > 0:
                    scores_lm, _ = lm.predict(ys_in, ylens_in,
                                              states=None)  # (1, vocab)
                    scores += lm_weight * scores_lm[:, :self.vocab_size]

                if decode_ctc_weight > 0:
                    score_ctc_prev = beam["score_ctc"]
                    ctc_state_prev = beam["ctc_state"]
                    scores_topb, v_topb = torch.topk(scores,
                                                     k=ctc_beam_width,
                                                     dim=1)
                    scores_ctc, ctc_state = ctc_scorer(beam["hyp"], v_topb[0],
                                                       ctc_state_prev)
                    # re-calculate the joint score over the CTC candidate set
                    scores = (
                        (1 - decode_ctc_weight) * scores_att[:, v_topb[0]]
                        + decode_ctc_weight * np2tensor(scores_ctc - score_ctc_prev)
                    )
                    if lm_weight > 0:
                        scores += lm_weight * scores_lm[:, v_topb[0]]
                    scores_topk, ids_topk = torch.topk(scores,
                                                       k=beam_width,
                                                       dim=1)
                    v_topk = v_topb[:, ids_topk[0]]
                else:
                    scores_topk, v_topk = torch.topk(scores,
                                                     k=beam_width,
                                                     dim=1)

                for j in range(beam_width):
                    new_beam = {}
                    new_beam["score"] = beam["score"] + float(scores_topk[0, j])
                    new_beam["hyp"] = beam["hyp"] + [int(v_topk[0, j])]
                    if decode_ctc_weight > 0:
                        new_beam["score_ctc"] = scores_ctc[ids_topk[0, j]]
                        new_beam["ctc_state"] = ctc_state[ids_topk[0, j]]
                    new_beams.append(new_beam)

            # update `beams`
            beams = sorted(new_beams, key=lambda x: x["score"],
                           reverse=True)[:beam_width]

            beams_extend = []
            for beam in beams:
                # ended beams
                if beam["hyp"][-1] == self.eos_id:
                    hyp_noeos = strip_eos(beam["hyp"], self.eos_id)
                    # a hypothesis consisting of only <eos> is not acceptable
                    if len(hyp_noeos) < 1:
                        continue

                    # add length penalty
                    score = beam["score"] + len_weight * len(beam["hyp"])

                    results.append({"hyp": hyp_noeos, "score": score})

                    if len(results) >= beam_width:
                        break
                else:
                    beams_extend.append(beam)

            if len(results) >= beam_width:
                break

            beams = beams_extend

        results = sorted(results, key=lambda x: x["score"], reverse=True)
        hyps = [result["hyp"] for result in results]
        scores = [result["score"] for result in results]
        logits = None
        aligns = None

        return hyps, scores, logits, aligns
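
    # Hypothetical invocation (the encoder outputs `eouts`/`elens`, the `lm`
    # object, and all weights below are illustrative, not from the original):
    #
    #     hyps, scores, _, _ = model.decode(
    #         eouts, elens, beam_width=10, len_weight=1.0,
    #         lm=lm, lm_weight=0.3, decode_ctc_weight=0.3,
    #     )
    #     best_hyp = hyps[0]  # token ids of the top-scoring hypothesis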
Example #8
# `np2tensor`, `tensor2np`, `print_topk_probs`, `aggregate_logits`, and the
# parsed command-line `args` (for `args.lm_weight`) are assumed to be defined
# at module level.
def test_step(
    model,
    lm,
    data,
    blank_id,
    mask_id,
    mask_th,
    device,
    vocab,
    vocab_size,
    vocab_phone=None,
    debug=False,
):
    utt_id = data["utt_ids"][0]
    xs = data["xs"].to(device)
    xlens = data["xlens"].to(device)
    reftext = data["texts"][0]

    # ASR (word); beam_width=0 appears to be this codebase's convention for
    # greedy decoding that also returns frame-level alignments
    hyps, scores, logits, aligns = model.decode(xs, xlens, beam_width=0, len_weight=0)
    hyp = np.array(hyps[0])

    if len(hyp) < 1:
        return utt_id, [], [], reftext, 0, 0

    # ASR (phone)
    if vocab_phone is not None:
        hyps_phone, _, _, _ = model.decode(
            xs, xlens, beam_width=0, len_weight=0, decode_phone=True
        )
        hyp_phone = np.array(hyps_phone[0])

        if len(hyp_phone) < 1:
            return utt_id, [], [], reftext, 0, 0

    hyp_masked = hyp.copy()
    token_probs, token_probs_v = aggregate_logits(logits[0], aligns[0], blank_id)
    assert len(hyp) == len(token_probs)
    assert len(hyp) == len(token_probs_v)

    # mask less confident tokens
    mask_indices = token_probs_v < mask_th
    hyp_masked[mask_indices] = mask_id

    num_masked = sum(mask_indices)
    num_tokens = len(mask_indices)

    y = np2tensor(hyp_masked)

    if vocab_phone is None:
        logits = lm(y.unsqueeze(0).to(device))
    else:
        p = np2tensor(hyp_phone)
        logits = lm(y.unsqueeze(0).to(device), ps=p.unsqueeze(0).to(device))

    lm_token_probs = tensor2np(torch.softmax(logits[0], dim=-1))

    # fusion: interpolate ASR and LM distributions over the shared vocabulary
    token_probs_mix = (
        (1 - args.lm_weight) * token_probs[:, :vocab_size]
        + args.lm_weight * lm_token_probs[:, :vocab_size]
    )

    y_gen = np.argmax(token_probs_mix, axis=-1)

    hyp_cor = hyp.copy()
    hyp_cor[mask_indices] = y_gen[mask_indices]

    if debug:
        print(f"*** {utt_id} ***")
        print(f"Ref.: {reftext}")
        print(f"Hyp.(word): {' '.join(vocab.ids2tokens(hyp))}")
        if vocab_phone is not None:
            print(f"Hyp.(phone): {' '.join(vocab_phone.ids2tokens(hyp_phone))}")
        print(
            f"Hyp.(masked): {' '.join(vocab.ids2tokens(hyp_masked))} ({num_masked:d}/{num_tokens:d} masked)"
        )
        print("ASR probs:")
        token_probs_masked = token_probs[mask_indices]
        print_topk_probs(token_probs_masked, vocab=vocab)
        print("LM probs:")
        lm_token_probs_masked = lm_token_probs[mask_indices]
        print_topk_probs(lm_token_probs_masked, vocab=vocab)
        print(f"Hyp.(correct): {' '.join(vocab.ids2tokens(hyp_cor))}")

    return utt_id, hyp, hyp_cor, reftext, num_masked, num_tokens
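
# `print_topk_probs` is referenced above but not defined in this excerpt; a
# minimal sketch consistent with its call sites (k=5 is an assumption):
def print_topk_probs(probs, vocab, k=5):
    for row in probs:  # one probability row per masked position
        idx = np.argsort(row)[::-1][:k]
        print(" | ".join(f"{vocab.i2t[i]}: {row[i]:.2f}" for i in idx))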