def align_length_sync_decoding(self, h: torch.Tensor) -> List[Hypothesis]:
        """Alignment-length synchronous beam search implementation.

        Based on https://ieeexplore.ieee.org/document/9053040

        Args:
            h: Encoded speech features (T_max, D_enc)

        Returns:
            nbest_hyps: N-best decoding results

        """
        beam = min(self.beam_size, self.vocab_size)

        h_length = int(h.size(0))
        # The outer loop runs over i = t + u; cap u so that t stays a valid
        # encoder frame index.
        u_max = min(self.u_max, (h_length - 1))

        init_tensor = h.unsqueeze(0)
        beam_state = self.decoder.init_state(
            torch.zeros((beam, self.hidden_size)))

        # Seed beam: single hypothesis holding only the blank label.
        B = [
            Hypothesis(
                yseq=[self.blank],
                score=0.0,
                dec_state=self.decoder.select_state(beam_state, 0),
            )
        ]
        final = []

        if self.lm:
            if hasattr(self.lm.predictor, "wordlm"):
                lm_model = self.lm.predictor.wordlm
                lm_type = "wordlm"
            else:
                lm_model = self.lm.predictor
                lm_type = "lm"

                # NOTE(review): lm_state is only initialized on the non-wordlm
                # path (kept as in the original) — confirm whether the wordlm
                # branch also needs an initial state.
                B[0].lm_state = init_lm_state(lm_model)

            lm_layers = len(lm_model.rnn)

        # Decoder output cache keyed by label prefix (filled by batch_score).
        cache = {}

        for i in range(h_length + u_max):
            A = []

            # Keep only hypotheses whose implied frame index t is in range.
            B_ = []
            h_states = []
            for hyp in B:
                u = len(hyp.yseq) - 1
                t = i - u + 1

                if t > (h_length - 1):
                    continue

                B_.append(hyp)
                h_states.append((t, h[t]))

            if B_:
                beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
                    B_, beam_state, cache, init_tensor)

                # hs is a (t, h[t]) pair; named ``hs`` (not ``h``) to avoid
                # shadowing the encoder output inside the comprehension.
                h_enc = torch.stack([hs[1] for hs in h_states])

                beam_logp = torch.log_softmax(self.decoder.joint_network(
                    h_enc, beam_y),
                                              dim=-1)
                # Top-k over non-blank labels only; +1 below restores the
                # original label indices.
                beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)

                if self.lm:
                    beam_lm_states = create_lm_batch_state(
                        [b.lm_state for b in B_], lm_type, lm_layers)

                    beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                        beam_lm_states, beam_lm_tokens, len(B_))

                # Use ``j`` here: the original reused ``i``, clobbering the
                # outer (t + u) step index for the rest of the iteration.
                for j, hyp in enumerate(B_):
                    # Blank transition: same label sequence, advance in time.
                    new_hyp = Hypothesis(
                        score=(hyp.score + float(beam_logp[j, 0])),
                        yseq=hyp.yseq[:],
                        dec_state=hyp.dec_state,
                        lm_state=hyp.lm_state,
                    )

                    A.append(new_hyp)

                    # Hypotheses that consumed the last frame are final.
                    if h_states[j][0] == (h_length - 1):
                        final.append(new_hyp)

                    # Non-blank expansions from the top-k labels.
                    for logp, k in zip(beam_topk[0][j], beam_topk[1][j] + 1):
                        new_hyp = Hypothesis(
                            score=(hyp.score + float(logp)),
                            yseq=(hyp.yseq[:] + [int(k)]),
                            dec_state=self.decoder.select_state(beam_state, j),
                            lm_state=hyp.lm_state,
                        )

                        if self.lm:
                            new_hyp.score += self.lm_weight * beam_lm_scores[j,
                                                                             k]

                            new_hyp.lm_state = select_lm_state(
                                beam_lm_states, j, lm_type, lm_layers)

                        A.append(new_hyp)

                B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
                # Merge hypotheses sharing the same label sequence.
                B = recombine_hyps(B)

        if final:
            return self.sort_nbest(final)
        else:
            return B
    def nsc_beam_search(self, h: torch.Tensor) -> List[Hypothesis]:
        """N-step constrained beam search implementation.

        Based and modified from https://arxiv.org/pdf/2002.03577.pdf.
        Please reference ESPnet (b-flo, PR #2444) for any usage outside ESPnet
        until further modifications.

        Note: the algorithm is not in his "complete" form but works almost as
        intended.

        Args:
            h: Encoded speech features (T_max, D_enc)

        Returns:
            nbest_hyps: N-best decoding results

        """
        beam = min(self.beam_size, self.vocab_size)
        # Per-hypothesis expansion count excludes the blank label.
        beam_k = min(beam, (self.vocab_size - 1))

        init_tensor = h.unsqueeze(0)
        # Index tensor holding the blank id; concatenated with the top-k below.
        blank_tensor = init_tensor.new_zeros(1, dtype=torch.long)

        beam_state = self.decoder.init_state(
            torch.zeros((beam, self.hidden_size)))

        # Seed hypothesis: blank-only prefix with zero score.
        init_tokens = [
            Hypothesis(
                yseq=[self.blank],
                score=0.0,
                dec_state=self.decoder.select_state(beam_state, 0),
            )
        ]

        # Decoder output cache keyed by label prefix (filled by batch_score).
        cache = {}

        beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
            init_tokens, beam_state, cache, init_tensor)

        state = self.decoder.select_state(beam_state, 0)

        if self.lm:
            beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                None, beam_lm_tokens, 1)

            if hasattr(self.lm.predictor, "wordlm"):
                lm_model = self.lm.predictor.wordlm
                lm_type = "wordlm"
            else:
                lm_model = self.lm.predictor
                lm_type = "lm"

            lm_layers = len(lm_model.rnn)

            lm_state = select_lm_state(beam_lm_states, 0, lm_type, lm_layers)
            lm_scores = beam_lm_scores[0]
        else:
            lm_state = None
            lm_scores = None

        kept_hyps = [
            Hypothesis(
                yseq=[self.blank],
                score=0.0,
                dec_state=state,
                y=[beam_y[0]],
                lm_state=lm_state,
                lm_scores=lm_scores,
            )
        ]

        # Time-synchronous loop over encoder frames.
        for hi in h:
            # Longest label sequences first so the prefix merge below always
            # sees the longer hypothesis at index j and its prefix at i > j.
            hyps = sorted(kept_hyps, key=lambda x: len(x.yseq), reverse=True)
            kept_hyps = []

            h_enc = hi.unsqueeze(0)

            # Prefix search: fold into hyps[j].score the probability of
            # reaching hyps[j] by extending its proper prefix hyps[i].
            for j in range(len(hyps) - 1):
                for i in range((j + 1), len(hyps)):
                    if (is_prefix(hyps[j].yseq, hyps[i].yseq)
                            and (len(hyps[j].yseq) - len(hyps[i].yseq)) <=
                            self.prefix_alpha):
                        next_id = len(hyps[i].yseq)

                        ytu = torch.log_softmax(self.decoder.joint_network(
                            hi, hyps[i].y[-1]),
                                                dim=0)

                        # Score of emitting the first missing label from the
                        # prefix hypothesis at this frame.
                        curr_score = hyps[i].score + float(
                            ytu[hyps[j].yseq[next_id]])

                        # Accumulate the remaining labels along the path.
                        for k in range(next_id, (len(hyps[j].yseq) - 1)):
                            ytu = torch.log_softmax(self.decoder.joint_network(
                                hi, hyps[j].y[k]),
                                                    dim=0)

                            curr_score += float(ytu[hyps[j].yseq[k + 1]])

                        hyps[j].score = np.logaddexp(hyps[j].score, curr_score)

            S = []  # hypotheses closed with blank at this frame
            V = []  # candidate (mostly non-blank) expansions
            for n in range(self.nstep):
                beam_y = torch.stack([hyp.y[-1] for hyp in hyps])

                beam_logp = torch.log_softmax(self.decoder.joint_network(
                    h_enc, beam_y),
                                              dim=-1)
                # Top-k over non-blank labels; +1 below restores label ids.
                beam_topk = beam_logp[:, 1:].topk(beam_k, dim=-1)

                if self.lm:
                    beam_lm_scores = torch.stack(
                        [hyp.lm_scores for hyp in hyps])

                for i, hyp in enumerate(hyps):
                    # Candidate labels: top-k non-blank plus the blank label.
                    i_topk = (
                        torch.cat((beam_topk[0][i], beam_logp[i, 0:1])),
                        torch.cat((beam_topk[1][i] + 1, blank_tensor)),
                    )

                    for logp, k in zip(*i_topk):
                        new_hyp = Hypothesis(
                            yseq=hyp.yseq[:],
                            score=(hyp.score + float(logp)),
                            y=hyp.y[:],
                            dec_state=hyp.dec_state,
                            lm_state=hyp.lm_state,
                            lm_scores=hyp.lm_scores,
                        )

                        if k == self.blank:
                            S.append(new_hyp)
                        else:
                            new_hyp.yseq.append(int(k))

                            if self.lm:
                                new_hyp.score += self.lm_weight * float(
                                    beam_lm_scores[i, k])

                        V.append(new_hyp)

                # Keep the best expansions not already present in hyps.
                V = sorted(V, key=lambda x: x.score, reverse=True)
                V = substract(V, hyps)[:beam]

                l_state = [v.dec_state for v in V]
                l_tokens = [v.yseq for v in V]

                beam_state = self.decoder.create_batch_states(
                    beam_state, l_state, l_tokens)
                beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
                    V, beam_state, cache, init_tensor)

                if self.lm:
                    beam_lm_states = create_lm_batch_state(
                        [v.lm_state for v in V], lm_type, lm_layers)
                    beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                        beam_lm_states, beam_lm_tokens, len(V))

                if n < (self.nstep - 1):
                    # Intermediate step: refresh decoder/LM state and expand
                    # again from V on the next n-iteration.
                    for i, v in enumerate(V):
                        v.y.append(beam_y[i])

                        v.dec_state = self.decoder.select_state(beam_state, i)

                        if self.lm:
                            v.lm_state = select_lm_state(
                                beam_lm_states, i, lm_type, lm_layers)
                            v.lm_scores = beam_lm_scores[i]

                    hyps = V[:]
                else:
                    # Final step: add the blank log-probability of the updated
                    # decoder state (skipped when nstep == 1).
                    beam_logp = torch.log_softmax(self.decoder.joint_network(
                        h_enc, beam_y),
                                                  dim=-1)

                    for i, v in enumerate(V):
                        if self.nstep != 1:
                            v.score += float(beam_logp[i, 0])

                        v.y.append(beam_y[i])

                        v.dec_state = self.decoder.select_state(beam_state, i)

                        if self.lm:
                            v.lm_state = select_lm_state(
                                beam_lm_states, i, lm_type, lm_layers)
                            v.lm_scores = beam_lm_scores[i]

            kept_hyps = sorted((S + V), key=lambda x: x.score,
                               reverse=True)[:beam]

        return self.sort_nbest(kept_hyps)
# Example 3
    def nsc_beam_search(self, h: torch.Tensor) -> List[NSCHypothesis]:
        """N-step constrained beam search implementation.

        Based and modified from https://arxiv.org/pdf/2002.03577.pdf.
        Please reference ESPnet (b-flo, PR #2444) for any usage outside ESPnet
        until further modifications.

        Note: the algorithm is not in his "complete" form but works almost as
        intended.

        Args:
            h: Encoded speech features (T_max, D_enc)

        Returns:
            nbest_hyps: N-best decoding results

        """
        beam = min(self.beam_size, self.vocab_size)
        # Per-hypothesis expansion count excludes the blank label.
        beam_k = min(beam, (self.vocab_size - 1))

        beam_state = self.decoder.init_state(beam)

        # Seed hypothesis: blank-only prefix with zero score.
        init_tokens = [
            NSCHypothesis(
                yseq=[self.blank],
                score=0.0,
                dec_state=self.decoder.select_state(beam_state, 0),
            )
        ]

        # Decoder output cache keyed by label prefix (filled by batch_score).
        cache = {}

        beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
            init_tokens,
            beam_state,
            cache,
            self.use_lm,
        )

        state = self.decoder.select_state(beam_state, 0)

        if self.use_lm:
            beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                None, beam_lm_tokens, 1)
            lm_state = select_lm_state(beam_lm_states, 0, self.lm_layers,
                                       self.is_wordlm)
            lm_scores = beam_lm_scores[0]
        else:
            lm_state = None
            lm_scores = None

        kept_hyps = [
            NSCHypothesis(
                yseq=[self.blank],
                score=0.0,
                dec_state=state,
                y=[beam_y[0]],
                lm_state=lm_state,
                lm_scores=lm_scores,
            )
        ]

        # Time-synchronous loop over encoder frames.
        for hi in h:
            # Longest label sequences first so the prefix merge below always
            # sees the longer hypothesis first and its prefix afterwards.
            hyps = sorted(kept_hyps, key=lambda x: len(x.yseq), reverse=True)
            kept_hyps = []

            h_enc = hi.unsqueeze(0)

            # Prefix search: fold into hyp_j.score the probability of
            # reaching hyp_j by extending its proper prefix hyp_i.
            for j, hyp_j in enumerate(hyps[:-1]):
                for hyp_i in hyps[(j + 1):]:
                    curr_id = len(hyp_j.yseq)
                    next_id = len(hyp_i.yseq)

                    if (is_prefix(hyp_j.yseq, hyp_i.yseq)
                            and (curr_id - next_id) <= self.prefix_alpha):
                        ytu = torch.log_softmax(self.joint_network(
                            hi, hyp_i.y[-1]),
                                                dim=-1)

                        # Score of emitting the first missing label from the
                        # prefix hypothesis at this frame.
                        curr_score = hyp_i.score + float(
                            ytu[hyp_j.yseq[next_id]])

                        # Accumulate the remaining labels along the path.
                        for k in range(next_id, (curr_id - 1)):
                            ytu = torch.log_softmax(self.joint_network(
                                hi, hyp_j.y[k]),
                                                    dim=-1)

                            curr_score += float(ytu[hyp_j.yseq[k + 1]])

                        hyp_j.score = np.logaddexp(hyp_j.score, curr_score)

            S = []  # hypotheses closed with blank at this frame
            V = []  # candidate expansions
            for n in range(self.nstep):
                beam_y = torch.stack([hyp.y[-1] for hyp in hyps])

                beam_logp = torch.log_softmax(self.joint_network(
                    h_enc, beam_y),
                                              dim=-1)
                # Top-k over non-blank labels; +1 below restores label ids.
                beam_topk = beam_logp[:, 1:].topk(beam_k, dim=-1)

                for i, hyp in enumerate(hyps):
                    # Blank transition: kept in both S and V.
                    S.append(
                        NSCHypothesis(
                            yseq=hyp.yseq[:],
                            score=hyp.score + float(beam_logp[i, 0:1]),
                            y=hyp.y[:],
                            dec_state=hyp.dec_state,
                            lm_state=hyp.lm_state,
                            lm_scores=hyp.lm_scores,
                        ))
                    V.append(S[-1])

                    for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                        score = hyp.score + float(logp)

                        if self.use_lm:
                            score += self.lm_weight * float(hyp.lm_scores[k])

                        V.append(
                            NSCHypothesis(
                                yseq=hyp.yseq[:] + [int(k)],
                                score=score,
                                y=hyp.y[:],
                                dec_state=hyp.dec_state,
                                lm_state=hyp.lm_state,
                                lm_scores=hyp.lm_scores,
                            ))

                # Fix: the original had a stray trailing comma here, which
                # built and discarded a (None,) tuple after the in-place sort.
                V.sort(key=lambda x: x.score, reverse=True)
                # Keep the best expansions not already present in hyps.
                V = substract(V, hyps)[:beam]

                beam_state = self.decoder.create_batch_states(
                    beam_state,
                    [v.dec_state for v in V],
                    [v.yseq for v in V],
                )
                beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
                    V,
                    beam_state,
                    cache,
                    self.use_lm,
                )

                if self.use_lm:
                    beam_lm_states = create_lm_batch_state(
                        [v.lm_state for v in V], self.lm_layers,
                        self.is_wordlm)
                    beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                        beam_lm_states, beam_lm_tokens, len(V))

                if n < (self.nstep - 1):
                    # Intermediate step: refresh decoder/LM state and expand
                    # again from V on the next n-iteration.
                    for i, v in enumerate(V):
                        v.y.append(beam_y[i])

                        v.dec_state = self.decoder.select_state(beam_state, i)

                        if self.use_lm:
                            v.lm_state = select_lm_state(
                                beam_lm_states, i, self.lm_layers,
                                self.is_wordlm)
                            v.lm_scores = beam_lm_scores[i]

                    hyps = V[:]
                else:
                    # Final step: add the blank log-probability of the updated
                    # decoder state (skipped when nstep == 1).
                    beam_logp = torch.log_softmax(self.joint_network(
                        h_enc, beam_y),
                                                  dim=-1)

                    for i, v in enumerate(V):
                        if self.nstep != 1:
                            v.score += float(beam_logp[i, 0])

                        v.y.append(beam_y[i])

                        v.dec_state = self.decoder.select_state(beam_state, i)

                        if self.use_lm:
                            v.lm_state = select_lm_state(
                                beam_lm_states, i, self.lm_layers,
                                self.is_wordlm)
                            v.lm_scores = beam_lm_scores[i]

            kept_hyps = sorted((S + V), key=lambda x: x.score,
                               reverse=True)[:beam]

        return self.sort_nbest(kept_hyps)
    def time_sync_decoding(self, h: torch.Tensor) -> List[Hypothesis]:
        """Time synchronous beam search implementation.

        Based on https://ieeexplore.ieee.org/document/9053040

        Args:
            h: Encoded speech features (T_max, D_enc)

        Returns:
            nbest_hyps: N-best decoding results

        """
        beam = min(self.beam_size, self.vocab_size)

        init_tensor = h.unsqueeze(0)
        beam_state = self.decoder.init_state(
            torch.zeros((beam, self.hidden_size)))

        # Seed beam: single hypothesis holding only the blank label.
        B = [
            Hypothesis(
                yseq=[self.blank],
                score=0.0,
                dec_state=self.decoder.select_state(beam_state, 0),
            )
        ]

        if self.lm:
            if hasattr(self.lm.predictor, "wordlm"):
                lm_model = self.lm.predictor.wordlm
                lm_type = "wordlm"
            else:
                lm_model = self.lm.predictor
                lm_type = "lm"

                # NOTE(review): lm_state is only initialized on the non-wordlm
                # path (kept as in the original) — confirm whether the wordlm
                # branch also needs an initial state.
                B[0].lm_state = init_lm_state(lm_model)

            lm_layers = len(lm_model.rnn)

        # Decoder output cache keyed by label prefix (filled by batch_score).
        cache = {}

        # Time-synchronous loop over encoder frames.
        for hi in h:
            A = []
            C = B

            h_enc = hi.unsqueeze(0)

            # Up to max_sym_exp symbol expansions per frame.
            for v in range(self.max_sym_exp):
                D = []

                beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
                    C, beam_state, cache, init_tensor)

                beam_logp = torch.log_softmax(self.decoder.joint_network(
                    h_enc, beam_y),
                                              dim=-1)
                # Top-k over non-blank labels; +1 below restores label ids.
                beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)

                # NOTE(review): computed once per expansion step, so entries
                # appended to A below are not seen by later iterations of the
                # C loop (kept as in the original / upstream ESPnet).
                seq_A = [h.yseq for h in A]

                for i, hyp in enumerate(C):
                    if hyp.yseq not in seq_A:
                        # Blank transition: close the hypothesis at this frame.
                        A.append(
                            Hypothesis(
                                score=(hyp.score + float(beam_logp[i, 0])),
                                yseq=hyp.yseq[:],
                                dec_state=hyp.dec_state,
                                lm_state=hyp.lm_state,
                            ))
                    else:
                        # Same label sequence already in A: merge scores.
                        dict_pos = seq_A.index(hyp.yseq)

                        A[dict_pos].score = np.logaddexp(
                            A[dict_pos].score,
                            (hyp.score + float(beam_logp[i, 0])))

                # Fix: the original tested ``v < self.max_sym_exp``, which is
                # always true inside this loop, so the last iteration built an
                # expansion set D/C that was immediately discarded. Skipping
                # it leaves the returned result unchanged.
                if v < (self.max_sym_exp - 1):
                    if self.lm:
                        beam_lm_states = create_lm_batch_state(
                            [c.lm_state for c in C], lm_type, lm_layers)

                        beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                            beam_lm_states, beam_lm_tokens, len(C))

                    # Non-blank expansions become the next step's candidates.
                    for i, hyp in enumerate(C):
                        for logp, k in zip(beam_topk[0][i],
                                           beam_topk[1][i] + 1):
                            new_hyp = Hypothesis(
                                score=(hyp.score + float(logp)),
                                yseq=(hyp.yseq + [int(k)]),
                                dec_state=self.decoder.select_state(
                                    beam_state, i),
                                lm_state=hyp.lm_state,
                            )

                            if self.lm:
                                new_hyp.score += self.lm_weight * beam_lm_scores[
                                    i, k]

                                new_hyp.lm_state = select_lm_state(
                                    beam_lm_states, i, lm_type, lm_layers)

                            D.append(new_hyp)

                C = sorted(D, key=lambda x: x.score, reverse=True)[:beam]

            B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]

        return self.sort_nbest(B)
# Example 5
    def nsc_beam_search(self, enc_out: torch.Tensor) -> List[ExtendedHypothesis]:
        """N-step constrained beam search implementation.

        Based on/Modified from https://arxiv.org/pdf/2002.03577.pdf.
        Please reference ESPnet (b-flo, PR #2444) for any usage outside ESPnet
        until further modifications.

        Args:
            enc_out: Encoder output sequence. (T, D_enc)

        Returns:
            nbest_hyps: N-best hypothesis.

        """
        beam = min(self.beam_size, self.vocab_size)
        # Per-hypothesis expansion count excludes the blank label.
        beam_k = min(beam, (self.vocab_size - 1))

        beam_state = self.decoder.init_state(beam)

        # Seed hypothesis: blank-only prefix with zero score.
        init_tokens = [
            ExtendedHypothesis(
                yseq=[self.blank_id],
                score=0.0,
                dec_state=self.decoder.select_state(beam_state, 0),
            )
        ]

        # Decoder output cache keyed by label prefix (filled by batch_score).
        cache = {}

        beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
            init_tokens,
            beam_state,
            cache,
            self.use_lm,
        )

        state = self.decoder.select_state(beam_state, 0)

        if self.use_lm:
            beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                None, beam_lm_tokens, 1
            )
            lm_state = select_lm_state(
                beam_lm_states, 0, self.lm_layers, self.is_wordlm
            )
            lm_scores = beam_lm_scores[0]
        else:
            lm_state = None
            lm_scores = None

        kept_hyps = [
            ExtendedHypothesis(
                yseq=[self.blank_id],
                score=0.0,
                dec_state=state,
                dec_out=[beam_dec_out[0]],
                lm_state=lm_state,
                lm_scores=lm_scores,
            )
        ]

        # Time-synchronous loop over encoder frames.
        for enc_out_t in enc_out:
            # Merge prefix probabilities before expanding (longest first).
            hyps = self.prefix_search(
                sorted(kept_hyps, key=lambda x: len(x.yseq), reverse=True),
                enc_out_t,
            )
            kept_hyps = []

            beam_enc_out = enc_out_t.unsqueeze(0)

            S = []  # hypotheses closed with blank at this frame
            V = []  # candidate non-blank expansions
            for n in range(self.nstep):
                beam_dec_out = torch.stack([hyp.dec_out[-1] for hyp in hyps])

                beam_logp = torch.log_softmax(
                    self.joint_network(beam_enc_out, beam_dec_out)
                    / self.softmax_temperature,
                    dim=-1,
                )
                # Top-k over non-blank labels; +1 below restores label ids.
                beam_topk = beam_logp[:, 1:].topk(beam_k, dim=-1)

                for i, hyp in enumerate(hyps):
                    # Blank transition: label sequence unchanged.
                    S.append(
                        ExtendedHypothesis(
                            yseq=hyp.yseq[:],
                            score=hyp.score + float(beam_logp[i, 0:1]),
                            dec_out=hyp.dec_out[:],
                            dec_state=hyp.dec_state,
                            lm_state=hyp.lm_state,
                            lm_scores=hyp.lm_scores,
                        )
                    )

                    for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                        score = hyp.score + float(logp)

                        if self.use_lm:
                            score += self.lm_weight * float(hyp.lm_scores[k])

                        V.append(
                            ExtendedHypothesis(
                                yseq=hyp.yseq[:] + [int(k)],
                                score=score,
                                dec_out=hyp.dec_out[:],
                                dec_state=hyp.dec_state,
                                lm_state=hyp.lm_state,
                                lm_scores=hyp.lm_scores,
                            )
                        )

                # Keep the best expansions not already present in hyps.
                V.sort(key=lambda x: x.score, reverse=True)
                V = subtract(V, hyps)[:beam]

                beam_state = self.decoder.create_batch_states(
                    beam_state,
                    [v.dec_state for v in V],
                    [v.yseq for v in V],
                )
                beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
                    V,
                    beam_state,
                    cache,
                    self.use_lm,
                )

                if self.use_lm:
                    beam_lm_states = create_lm_batch_states(
                        [v.lm_state for v in V], self.lm_layers, self.is_wordlm
                    )
                    beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                        beam_lm_states, beam_lm_tokens, len(V)
                    )

                if n < (self.nstep - 1):
                    # Intermediate step: refresh decoder/LM state and expand
                    # again from V on the next n-iteration.
                    for i, v in enumerate(V):
                        v.dec_out.append(beam_dec_out[i])

                        v.dec_state = self.decoder.select_state(beam_state, i)

                        if self.use_lm:
                            v.lm_state = select_lm_state(
                                beam_lm_states, i, self.lm_layers, self.is_wordlm
                            )
                            v.lm_scores = beam_lm_scores[i]

                    hyps = V[:]
                else:
                    # Final step: add the blank log-probability of the updated
                    # decoder state (skipped when nstep == 1).
                    beam_logp = torch.log_softmax(
                        self.joint_network(beam_enc_out, beam_dec_out)
                        / self.softmax_temperature,
                        dim=-1,
                    )

                    for i, v in enumerate(V):
                        if self.nstep != 1:
                            v.score += float(beam_logp[i, 0])

                        v.dec_out.append(beam_dec_out[i])

                        v.dec_state = self.decoder.select_state(beam_state, i)

                        if self.use_lm:
                            v.lm_state = select_lm_state(
                                beam_lm_states, i, self.lm_layers, self.is_wordlm
                            )
                            v.lm_scores = beam_lm_scores[i]

            kept_hyps = sorted((S + V), key=lambda x: x.score, reverse=True)[:beam]

        return self.sort_nbest(kept_hyps)
# Example 6
    def modified_adaptive_expansion_search(
        self, enc_out: torch.Tensor
    ) -> List[ExtendedHypothesis]:
        """It's the modified Adaptive Expansion Search (mAES) implementation.

        Based on/modified from https://ieeexplore.ieee.org/document/9250505 and NSC.

        Args:
            enc_out: Encoder output sequence. (T, D_enc)

        Returns:
            nbest_hyps: N-best hypothesis.

        """
        beam = min(self.beam_size, self.vocab_size)
        beam_state = self.decoder.init_state(beam)

        # Seed hypothesis: blank-only prefix with zero score.
        init_tokens = [
            ExtendedHypothesis(
                yseq=[self.blank_id],
                score=0.0,
                dec_state=self.decoder.select_state(beam_state, 0),
            )
        ]

        # Decoder output cache keyed by label prefix (filled by batch_score).
        cache = {}

        beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
            init_tokens,
            beam_state,
            cache,
            self.use_lm,
        )

        state = self.decoder.select_state(beam_state, 0)

        if self.use_lm:
            beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                None, beam_lm_tokens, 1
            )
            lm_state = select_lm_state(
                beam_lm_states, 0, self.lm_layers, self.is_wordlm
            )
            lm_scores = beam_lm_scores[0]
        else:
            lm_state = None
            lm_scores = None

        kept_hyps = [
            ExtendedHypothesis(
                yseq=[self.blank_id],
                score=0.0,
                dec_state=state,
                dec_out=[beam_dec_out[0]],
                lm_state=lm_state,
                lm_scores=lm_scores,
            )
        ]

        # Time-synchronous loop over encoder frames.
        for enc_out_t in enc_out:
            # Merge prefix probabilities before expanding (longest first).
            hyps = self.prefix_search(
                sorted(kept_hyps, key=lambda x: len(x.yseq), reverse=True),
                enc_out_t,
            )
            kept_hyps = []

            beam_enc_out = enc_out_t.unsqueeze(0)

            list_b = []  # hypotheses closed with blank at this frame
            for n in range(self.nstep):
                beam_dec_out = torch.stack([h.dec_out[-1] for h in hyps])

                beam_logp = torch.log_softmax(
                    self.joint_network(beam_enc_out, beam_dec_out)
                    / self.softmax_temperature,
                    dim=-1,
                )
                # Adaptive pruning: per-hypothesis candidate labels chosen by
                # gamma-margin and beta-width (see the mAES paper).
                k_expansions = select_k_expansions(
                    hyps, beam_logp, beam, self.expansion_gamma, self.expansion_beta
                )

                list_exp = []  # non-blank expansions for this step
                for i, hyp in enumerate(hyps):
                    for k, new_score in k_expansions[i]:
                        new_hyp = ExtendedHypothesis(
                            yseq=hyp.yseq[:],
                            score=new_score,
                            dec_out=hyp.dec_out[:],
                            dec_state=hyp.dec_state,
                            lm_state=hyp.lm_state,
                            lm_scores=hyp.lm_scores,
                        )

                        if k == 0:
                            # k == 0 is the blank label: hypothesis is closed.
                            list_b.append(new_hyp)
                        else:
                            new_hyp.yseq.append(int(k))

                            if self.use_lm:
                                new_hyp.score += self.lm_weight * float(
                                    hyp.lm_scores[k]
                                )

                            list_exp.append(new_hyp)

                if not list_exp:
                    # No non-blank expansion survived: stop early with the
                    # blank-closed hypotheses.
                    kept_hyps = sorted(list_b, key=lambda x: x.score, reverse=True)[
                        :beam
                    ]

                    break
                else:
                    beam_state = self.decoder.create_batch_states(
                        beam_state,
                        [hyp.dec_state for hyp in list_exp],
                        [hyp.yseq for hyp in list_exp],
                    )

                    beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
                        list_exp,
                        beam_state,
                        cache,
                        self.use_lm,
                    )

                    if self.use_lm:
                        beam_lm_states = create_lm_batch_states(
                            [hyp.lm_state for hyp in list_exp],
                            self.lm_layers,
                            self.is_wordlm,
                        )
                        beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                            beam_lm_states, beam_lm_tokens, len(list_exp)
                        )

                    if n < (self.nstep - 1):
                        # Intermediate step: refresh decoder/LM state and
                        # expand again from list_exp.
                        for i, hyp in enumerate(list_exp):
                            hyp.dec_out.append(beam_dec_out[i])
                            hyp.dec_state = self.decoder.select_state(beam_state, i)

                            if self.use_lm:
                                hyp.lm_state = select_lm_state(
                                    beam_lm_states, i, self.lm_layers, self.is_wordlm
                                )
                                hyp.lm_scores = beam_lm_scores[i]

                        hyps = list_exp[:]
                    else:
                        # Final step: add the blank log-probability of the
                        # updated decoder state, then keep the beam best of
                        # blank-closed plus expanded hypotheses.
                        beam_logp = torch.log_softmax(
                            self.joint_network(beam_enc_out, beam_dec_out)
                            / self.softmax_temperature,
                            dim=-1,
                        )

                        for i, hyp in enumerate(list_exp):
                            hyp.score += float(beam_logp[i, 0])

                            hyp.dec_out.append(beam_dec_out[i])
                            hyp.dec_state = self.decoder.select_state(beam_state, i)

                            if self.use_lm:
                                hyp.lm_state = select_lm_state(
                                    beam_lm_states, i, self.lm_layers, self.is_wordlm
                                )
                                hyp.lm_scores = beam_lm_scores[i]

                        kept_hyps = sorted(
                            list_b + list_exp, key=lambda x: x.score, reverse=True
                        )[:beam]

        return self.sort_nbest(kept_hyps)
# Example #7
    def align_length_sync_decoding(self, enc_out: torch.Tensor) -> List[Hypothesis]:
        """Alignment-length synchronous beam search implementation.

        Based on https://ieeexplore.ieee.org/document/9053040

        Args:
            enc_out: Encoder output sequence. (T, D)

        Returns:
            nbest_hyps: N-best hypothesis.

        """
        beam = min(self.beam_size, self.vocab_size)

        t_max = int(enc_out.size(0))
        # Cap the number of emitted labels: at most (T - 1) label expansions.
        u_max = min(self.u_max, (t_max - 1))

        beam_state = self.decoder.init_state(beam)

        # Start from a single hypothesis holding only the blank symbol.
        B = [
            Hypothesis(
                yseq=[self.blank_id],
                score=0.0,
                dec_state=self.decoder.select_state(beam_state, 0),
            )
        ]
        final = []
        cache = {}

        if self.use_lm and not self.is_wordlm:
            B[0].lm_state = init_lm_state(self.lm_predictor)

        # i enumerates alignment steps: i = t + u (frame index + label count).
        for i in range(t_max + u_max):
            A = []

            B_ = []
            B_enc_out = []
            for hyp in B:
                u = len(hyp.yseq) - 1
                t = i - u

                # Hypothesis has already consumed every encoder frame; drop it
                # from this alignment step.
                if t > (t_max - 1):
                    continue

                B_.append(hyp)
                B_enc_out.append((t, enc_out[t]))

            if B_:
                beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
                    B_,
                    beam_state,
                    cache,
                    self.use_lm,
                )

                beam_enc_out = torch.stack([x[1] for x in B_enc_out])

                beam_logp = torch.log_softmax(
                    self.joint_network(beam_enc_out, beam_dec_out)
                    / self.softmax_temperature,
                    dim=-1,
                )
                # Top-k over non-blank tokens only; index 0 is excluded here
                # and treated as blank below (beam_logp[i, 0]).
                beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)

                if self.use_lm:
                    beam_lm_states = create_lm_batch_states(
                        [b.lm_state for b in B_], self.lm_layers, self.is_wordlm
                    )

                    beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                        beam_lm_states, beam_lm_tokens, len(B_)
                    )

                # NOTE(review): this loop variable shadows the outer alignment
                # index `i`; harmless since the outer loop reassigns it, but
                # easy to misread.
                for i, hyp in enumerate(B_):
                    # Blank extension: frame advances, label sequence unchanged.
                    new_hyp = Hypothesis(
                        score=(hyp.score + float(beam_logp[i, 0])),
                        yseq=hyp.yseq[:],
                        dec_state=hyp.dec_state,
                        lm_state=hyp.lm_state,
                    )

                    A.append(new_hyp)

                    # A blank emitted on the last frame finalizes the hypothesis.
                    if B_enc_out[i][0] == (t_max - 1):
                        final.append(new_hyp)

                    # Label extensions: one new hypothesis per top-k non-blank
                    # token (+1 restores vocabulary indices after slicing).
                    for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                        new_hyp = Hypothesis(
                            score=(hyp.score + float(logp)),
                            yseq=(hyp.yseq[:] + [int(k)]),
                            dec_state=self.decoder.select_state(beam_state, i),
                            lm_state=hyp.lm_state,
                        )

                        if self.use_lm:
                            new_hyp.score += self.lm_weight * beam_lm_scores[i, k]

                            new_hyp.lm_state = select_lm_state(
                                beam_lm_states, i, self.lm_layers, self.is_wordlm
                            )

                        A.append(new_hyp)

                # Prune to beam width, then merge hypotheses sharing the same
                # label sequence.
                B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
                B = recombine_hyps(B)

        if final:
            return self.sort_nbest(final)
        else:
            return B
# Example #8
    def time_sync_decoding(self, enc_out: torch.Tensor) -> List[Hypothesis]:
        """Time synchronous beam search implementation.

        Based on https://ieeexplore.ieee.org/document/9053040

        Args:
            enc_out: Encoder output sequence. (T, D)

        Returns:
            nbest_hyps: N-best hypothesis.

        """
        beam = min(self.beam_size, self.vocab_size)

        beam_state = self.decoder.init_state(beam)

        # Start from a single hypothesis holding only the blank symbol.
        B = [
            Hypothesis(
                yseq=[self.blank_id],
                score=0.0,
                dec_state=self.decoder.select_state(beam_state, 0),
            )
        ]
        cache = {}

        if self.use_lm and not self.is_wordlm:
            B[0].lm_state = init_lm_state(self.lm_predictor)

        # Process one encoder frame per outer iteration (time-synchronous).
        for enc_out_t in enc_out:
            A = []
            C = B

            enc_out_t = enc_out_t.unsqueeze(0)

            # Allow at most max_sym_exp label emissions per frame.
            for v in range(self.max_sym_exp):
                D = []

                beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
                    C,
                    beam_state,
                    cache,
                    self.use_lm,
                )

                beam_logp = torch.log_softmax(
                    self.joint_network(enc_out_t, beam_dec_out)
                    / self.softmax_temperature,
                    dim=-1,
                )
                # Top-k over non-blank tokens; index 0 is treated as blank.
                beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)

                seq_A = [h.yseq for h in A]

                for i, hyp in enumerate(C):
                    if hyp.yseq not in seq_A:
                        # Blank emission moves the hypothesis into A unchanged.
                        A.append(
                            Hypothesis(
                                score=(hyp.score + float(beam_logp[i, 0])),
                                yseq=hyp.yseq[:],
                                dec_state=hyp.dec_state,
                                lm_state=hyp.lm_state,
                            )
                        )
                    else:
                        # Same label sequence already in A: merge probability
                        # mass in log space instead of duplicating.
                        dict_pos = seq_A.index(hyp.yseq)

                        A[dict_pos].score = np.logaddexp(
                            A[dict_pos].score, (hyp.score + float(beam_logp[i, 0]))
                        )

                # Expand non-blank symbols only if another expansion step
                # follows; on the last step D/C would be discarded.
                if v < (self.max_sym_exp - 1):
                    if self.use_lm:
                        beam_lm_states = create_lm_batch_states(
                            [c.lm_state for c in C], self.lm_layers, self.is_wordlm
                        )

                        beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                            beam_lm_states, beam_lm_tokens, len(C)
                        )

                    for i, hyp in enumerate(C):
                        # +1 restores vocabulary indices after the blank slice.
                        for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                            new_hyp = Hypothesis(
                                score=(hyp.score + float(logp)),
                                yseq=(hyp.yseq + [int(k)]),
                                dec_state=self.decoder.select_state(beam_state, i),
                                lm_state=hyp.lm_state,
                            )

                            if self.use_lm:
                                new_hyp.score += self.lm_weight * beam_lm_scores[i, k]

                                new_hyp.lm_state = select_lm_state(
                                    beam_lm_states, i, self.lm_layers, self.is_wordlm
                                )

                            D.append(new_hyp)

                C = sorted(D, key=lambda x: x.score, reverse=True)[:beam]

            B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]

        return self.sort_nbest(B)
# Example #9
def nsc_beam_search(decoder, h, recog_args, rnnlm=None):
    """N-step constrained beam search implementation.

    Based and modified from https://arxiv.org/pdf/2002.03577.pdf.
    Please reference ESPnet (b-flo, PR #2444) for any usage outside ESPnet
    until further modifications.

    Note: the algorithm is not in its "complete" form but works almost as
          intended.

    Args:
        decoder (class): decoder class
        h (torch.Tensor): encoder hidden state sequences (Tmax, Henc)
        recog_args (Namespace): argument Namespace containing options
        rnnlm (torch.nn.Module): language module

    Returns:
        nbest_hyps (list of dicts): n-best decoding results

    """
    beam = min(recog_args.beam_size, decoder.odim)
    # k for non-blank top-k: at most (vocab - 1) non-blank tokens exist.
    beam_k = min(beam, (decoder.odim - 1))

    nstep = recog_args.nstep
    prefix_alpha = recog_args.prefix_alpha

    nbest = recog_args.nbest

    cache = {}

    init_tensor = h.unsqueeze(0)
    # Single-element tensor holding token id 0, used to append the blank
    # option to each hypothesis' top-k candidates.
    blank_tensor = init_tensor.new_zeros(1, dtype=torch.long)

    beam_state = decoder.init_state(torch.zeros((beam, decoder.dunits)))

    init_tokens = [
        Hypothesis(
            yseq=[decoder.blank],
            score=0.0,
            dec_state=decoder.select_state(beam_state, 0),
        )
    ]

    # Pre-compute the decoder output for the initial blank-only hypothesis.
    beam_y, beam_state, beam_lm_tokens = decoder.batch_score(
        init_tokens, beam_state, cache, init_tensor)

    state = decoder.select_state(beam_state, 0)

    if rnnlm:
        beam_lm_states, beam_lm_scores = rnnlm.buff_predict(
            None, beam_lm_tokens, 1)

        if hasattr(rnnlm.predictor, "wordlm"):
            lm_model = rnnlm.predictor.wordlm
            lm_type = "wordlm"
        else:
            lm_model = rnnlm.predictor
            lm_type = "lm"

        lm_layers = len(lm_model.rnn)

        lm_state = select_lm_state(beam_lm_states, 0, lm_type, lm_layers)
        lm_scores = beam_lm_scores[0]
    else:
        lm_state = None
        lm_scores = None

    kept_hyps = [
        Hypothesis(
            yseq=[decoder.blank],
            score=0.0,
            dec_state=state,
            y=[beam_y[0]],
            lm_state=lm_state,
            lm_scores=lm_scores,
        )
    ]

    for hi in h:
        # Longest hypotheses first so prefix relations below go one way.
        hyps = sorted(kept_hyps, key=lambda x: len(x.yseq), reverse=True)
        kept_hyps = []

        h_enc = hi.unsqueeze(0)

        # Prefix search: fold the probability mass of each hypothesis i that
        # is a proper prefix of hypothesis j (within prefix_alpha length
        # difference) into hypothesis j's score.
        for j in range(len(hyps) - 1):
            for i in range((j + 1), len(hyps)):
                if (is_prefix(hyps[j].yseq, hyps[i].yseq) and
                    (len(hyps[j].yseq) - len(hyps[i].yseq)) <= prefix_alpha):
                    next_id = len(hyps[i].yseq)

                    ytu = F.log_softmax(decoder.joint(hi, hyps[i].y[-1]),
                                        dim=0)

                    # Probability of extending the prefix by its next label.
                    curr_score = hyps[i].score + float(
                        ytu[hyps[j].yseq[next_id]])

                    # Chain the remaining labels of the longer hypothesis.
                    for k in range(next_id, (len(hyps[j].yseq) - 1)):
                        ytu = F.log_softmax(decoder.joint(hi, hyps[j].y[k]),
                                            dim=0)

                        curr_score += float(ytu[hyps[j].yseq[k + 1]])

                    hyps[j].score = np.logaddexp(hyps[j].score, curr_score)

        # S collects blank (stay) extensions, V collects label extensions.
        S = []
        V = []
        for n in range(nstep):
            beam_y = torch.stack([hyp.y[-1] for hyp in hyps])

            beam_logp = F.log_softmax(decoder.joint(h_enc, beam_y), dim=-1)
            # Non-blank top-k; +1 below restores vocabulary indices.
            beam_topk = beam_logp[:, 1:].topk(beam_k, dim=-1)

            if rnnlm:
                beam_lm_scores = torch.stack([hyp.lm_scores for hyp in hyps])

            for i, hyp in enumerate(hyps):
                # Candidate set: top-k non-blank tokens plus the blank token.
                i_topk = (
                    torch.cat((beam_topk[0][i], beam_logp[i, 0:1])),
                    torch.cat((beam_topk[1][i] + 1, blank_tensor)),
                )

                for logp, k in zip(*i_topk):
                    new_hyp = Hypothesis(
                        yseq=hyp.yseq[:],
                        score=(hyp.score + float(logp)),
                        y=hyp.y[:],
                        dec_state=hyp.dec_state,
                        lm_state=hyp.lm_state,
                        lm_scores=hyp.lm_scores,
                    )

                    if k == decoder.blank:
                        S.append(new_hyp)
                    else:
                        new_hyp.yseq.append(int(k))

                        if rnnlm:
                            new_hyp.score += recog_args.lm_weight * float(
                                beam_lm_scores[i, k])

                        V.append(new_hyp)

            V = sorted(V, key=lambda x: x.score, reverse=True)
            # Remove duplicates of current hypotheses, then prune to beam.
            # NOTE(review): `substract` is presumably a "subtract" helper
            # defined elsewhere in this module — verify its semantics there.
            V = substract(V, hyps)[:beam]

            l_state = [v.dec_state for v in V]
            l_tokens = [v.yseq for v in V]

            beam_state = decoder.create_batch_states(beam_state, l_state,
                                                     l_tokens)
            beam_y, beam_state, beam_lm_tokens = decoder.batch_score(
                V, beam_state, cache, init_tensor)

            if rnnlm:
                beam_lm_states = create_lm_batch_state([v.lm_state for v in V],
                                                       lm_type, lm_layers)
                beam_lm_states, beam_lm_scores = rnnlm.buff_predict(
                    beam_lm_states, beam_lm_tokens, len(V))

            if n < (nstep - 1):
                # Intermediate step: carry updated states into the next step.
                for i, v in enumerate(V):
                    v.y.append(beam_y[i])

                    v.dec_state = decoder.select_state(beam_state, i)

                    if rnnlm:
                        v.lm_state = select_lm_state(beam_lm_states, i,
                                                     lm_type, lm_layers)
                        v.lm_scores = beam_lm_scores[i]

                hyps = V[:]
            else:
                # Final step: add the blank score so V is comparable with S.
                beam_logp = F.log_softmax(decoder.joint(h_enc, beam_y), dim=-1)

                for i, v in enumerate(V):
                    if nstep != 1:
                        v.score += float(beam_logp[i, 0])

                    v.y.append(beam_y[i])

                    v.dec_state = decoder.select_state(beam_state, i)

                    if rnnlm:
                        v.lm_state = select_lm_state(beam_lm_states, i,
                                                     lm_type, lm_layers)
                        v.lm_scores = beam_lm_scores[i]

        kept_hyps = sorted((S + V), key=lambda x: x.score, reverse=True)[:beam]

    # Rank by length-normalized score for the final n-best list.
    nbest_hyps = sorted(kept_hyps,
                        key=lambda x: (x.score / len(x.yseq)),
                        reverse=True)[:nbest]

    return [asdict(n) for n in nbest_hyps]
# Example #10
def align_length_sync_decoding(decoder, h, recog_args, rnnlm=None):
    """Alignment-length synchronous beam search implementation.

    Based on https://ieeexplore.ieee.org/document/9053040

    Args:
        decoder (class): decoder class
        h (torch.Tensor): encoder hidden state sequences (Tmax, Henc)
        recog_args (Namespace): argument Namespace containing options
        rnnlm (torch.nn.Module): language module

    Returns:
        nbest_hyps (list of dicts): n-best decoding results

    """
    beam = min(recog_args.beam_size, decoder.odim)

    h_length = int(h.size(0))
    # Cap the number of emitted labels: at most (Tmax - 1) label expansions.
    u_max = min(recog_args.u_max, (h_length - 1))

    nbest = recog_args.nbest

    init_tensor = h.unsqueeze(0)

    beam_state = decoder.init_state(torch.zeros((beam, decoder.dunits)))

    # Start from a single hypothesis holding only the blank symbol.
    B = [
        Hypothesis(
            yseq=[decoder.blank],
            score=0.0,
            dec_state=decoder.select_state(beam_state, 0),
        )
    ]
    final = []

    if rnnlm:
        if hasattr(rnnlm.predictor, "wordlm"):
            lm_model = rnnlm.predictor.wordlm
            lm_type = "wordlm"
        else:
            lm_model = rnnlm.predictor
            lm_type = "lm"

            B[0].lm_state = init_lm_state(lm_model)

        lm_layers = len(lm_model.rnn)

    cache = {}

    # i enumerates alignment steps: i = t + u (frame index + label count).
    for i in range(h_length + u_max):
        A = []

        B_ = []
        h_states = []
        for hyp in B:
            u = len(hyp.yseq) - 1
            # Fix: t = i - u gives the 0-based frame for this hypothesis.
            # The previous "t = i - u + 1" made t >= 1 always, so encoder
            # frame h[0] was never scored by any hypothesis.
            t = i - u

            # Hypothesis has already consumed every encoder frame; skip it.
            if t > (h_length - 1):
                continue

            B_.append(hyp)
            h_states.append((t, h[t]))

        if B_:
            beam_y, beam_state, beam_lm_tokens = decoder.batch_score(
                B_, beam_state, cache, init_tensor)

            # Loop variable renamed from `h` to avoid shadowing the encoder
            # tensor argument inside the comprehension.
            h_enc = torch.stack([hs[1] for hs in h_states])

            beam_logp = F.log_softmax(decoder.joint(h_enc, beam_y), dim=-1)
            # Top-k over non-blank tokens; index 0 is treated as blank below.
            beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)

            if rnnlm:
                beam_lm_states = create_lm_batch_state(
                    [b.lm_state for b in B_], lm_type, lm_layers)

                beam_lm_states, beam_lm_scores = rnnlm.buff_predict(
                    beam_lm_states, beam_lm_tokens, len(B_))

            # Inner index renamed from `i` to `j` to stop shadowing the
            # outer alignment-step index.
            for j, hyp in enumerate(B_):
                # Blank extension: frame advances, label sequence unchanged.
                new_hyp = Hypothesis(
                    score=(hyp.score + float(beam_logp[j, 0])),
                    yseq=hyp.yseq[:],
                    dec_state=hyp.dec_state,
                    lm_state=hyp.lm_state,
                )

                A.append(new_hyp)

                # A blank emitted on the last frame finalizes the hypothesis.
                if h_states[j][0] == (h_length - 1):
                    final.append(new_hyp)

                # Label extensions: one new hypothesis per top-k non-blank
                # token (+1 restores vocabulary indices after the slice).
                for logp, k in zip(beam_topk[0][j], beam_topk[1][j] + 1):
                    new_hyp = Hypothesis(
                        score=(hyp.score + float(logp)),
                        yseq=(hyp.yseq[:] + [int(k)]),
                        dec_state=decoder.select_state(beam_state, j),
                        lm_state=hyp.lm_state,
                    )

                    if rnnlm:
                        new_hyp.score += recog_args.lm_weight * beam_lm_scores[
                            j, k]

                        new_hyp.lm_state = select_lm_state(
                            beam_lm_states, j, lm_type, lm_layers)

                    A.append(new_hyp)

            # Prune to beam width, then merge hypotheses sharing the same
            # label sequence.
            B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
            B = recombine_hyps(B)

    if final:
        nbest_hyps = sorted(final, key=lambda x: x.score, reverse=True)[:nbest]
    else:
        nbest_hyps = B[:nbest]

    return [asdict(n) for n in nbest_hyps]
# Example #11
def time_sync_decoding(decoder, h, recog_args, rnnlm=None):
    """Time synchronous beam search implementation.

    Based on https://ieeexplore.ieee.org/document/9053040

    Args:
        decoder (class): decoder class
        h (torch.Tensor): encoder hidden state sequences (Tmax, Henc)
        recog_args (Namespace): argument Namespace containing options
        rnnlm (torch.nn.Module): language module

    Returns:
        nbest_hyps (list of dicts): n-best decoding results

    """
    beam = min(recog_args.beam_size, decoder.odim)

    max_sym_exp = recog_args.max_sym_exp
    nbest = recog_args.nbest

    init_tensor = h.unsqueeze(0)

    beam_state = decoder.init_state(torch.zeros((beam, decoder.dunits)))

    # Start from a single hypothesis holding only the blank symbol.
    B = [
        Hypothesis(
            yseq=[decoder.blank],
            score=0.0,
            dec_state=decoder.select_state(beam_state, 0),
        )
    ]

    if rnnlm:
        if hasattr(rnnlm.predictor, "wordlm"):
            lm_model = rnnlm.predictor.wordlm
            lm_type = "wordlm"
        else:
            lm_model = rnnlm.predictor
            lm_type = "lm"

            B[0].lm_state = init_lm_state(lm_model)

        lm_layers = len(lm_model.rnn)

    cache = {}

    # Process one encoder frame per outer iteration (time-synchronous).
    for hi in h:
        A = []
        C = B

        h_enc = hi.unsqueeze(0)

        # Allow at most max_sym_exp label emissions per frame.
        for v in range(max_sym_exp):
            D = []

            beam_y, beam_state, beam_lm_tokens = decoder.batch_score(
                C, beam_state, cache, init_tensor)

            beam_logp = F.log_softmax(decoder.joint(h_enc, beam_y), dim=-1)
            # Top-k over non-blank tokens; index 0 is treated as blank below.
            beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)

            # Loop variable renamed from `h` to avoid shadowing the encoder
            # tensor argument inside the comprehension.
            seq_A = [a.yseq for a in A]

            for i, hyp in enumerate(C):
                if hyp.yseq not in seq_A:
                    # Blank emission moves the hypothesis into A unchanged.
                    A.append(
                        Hypothesis(
                            score=(hyp.score + float(beam_logp[i, 0])),
                            yseq=hyp.yseq[:],
                            dec_state=hyp.dec_state,
                            lm_state=hyp.lm_state,
                        ))
                else:
                    # Same label sequence already in A: merge probability
                    # mass in log space instead of duplicating.
                    dict_pos = seq_A.index(hyp.yseq)

                    A[dict_pos].score = np.logaddexp(
                        A[dict_pos].score,
                        (hyp.score + float(beam_logp[i, 0])))

            # Fix: expand non-blank symbols only when another expansion step
            # follows. The previous "v < max_sym_exp" was always true, so the
            # last step built D and C that were immediately discarded (B is
            # taken from A only) — pure wasted computation.
            if v < (max_sym_exp - 1):
                if rnnlm:
                    beam_lm_states = create_lm_batch_state(
                        [c.lm_state for c in C], lm_type, lm_layers)

                    beam_lm_states, beam_lm_scores = rnnlm.buff_predict(
                        beam_lm_states, beam_lm_tokens, len(C))

                for i, hyp in enumerate(C):
                    # +1 restores vocabulary indices after the blank slice.
                    for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                        new_hyp = Hypothesis(
                            score=(hyp.score + float(logp)),
                            yseq=(hyp.yseq + [int(k)]),
                            dec_state=decoder.select_state(beam_state, i),
                            lm_state=hyp.lm_state,
                        )

                        if rnnlm:
                            new_hyp.score += recog_args.lm_weight * beam_lm_scores[
                                i, k]

                            new_hyp.lm_state = select_lm_state(
                                beam_lm_states, i, lm_type, lm_layers)

                        D.append(new_hyp)

                C = sorted(D, key=lambda x: x.score, reverse=True)[:beam]

        B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]

    nbest_hyps = sorted(B, key=lambda x: x.score, reverse=True)[:nbest]

    return [asdict(n) for n in nbest_hyps]