Example 1
    def ctc_decoder_predictions_tensor(
        self,
        predictions: torch.Tensor,
        predictions_len: torch.Tensor = None,
        return_hypotheses: bool = False,
    ) -> List[str]:
        """
        Decodes a sequence of labels to words

        Args:
            predictions: An integer torch.Tensor of shape [Batch, Time] (if ``batch_dim_index == 0``) or [Time, Batch]
                (if ``batch_dim_index == 1``) of integer indices that correspond to the index of some character in the
                label set.
            predictions_len: Optional tensor of length `Batch` which contains the integer lengths
                of the sequence in the padded `predictions` tensor.
            return_hypotheses: Bool flag whether to return just the decoding predictions of the model
                or a Hypothesis object that holds information such as the decoded `text`,
                the `alignments` emitted by the CTC model, and the `length` of the sequence (if available).
                May also contain the log-probabilities of the decoder (if this method is called via
                transcribe())

        Returns:
            Either a list of str which represent the CTC decoded strings per sample,
            or a list of Hypothesis objects containing additional information.
        """
        hypotheses = []
        # Move the batch dimension to the front, then move predictions to CPU
        predictions = move_dimension_to_the_front(predictions,
                                                  self.batch_dim_index)
        prediction_cpu_tensor = predictions.long().cpu()
        # iterate over batch
        for ind in range(prediction_cpu_tensor.shape[0]):
            prediction = prediction_cpu_tensor[ind].detach().numpy().tolist()
            if predictions_len is not None:
                prediction = prediction[:predictions_len[ind]]
            # CTC decoding procedure
            decoded_prediction = []
            previous = self.blank_id
            for p in prediction:
                if (p != previous
                        or previous == self.blank_id) and p != self.blank_id:
                    decoded_prediction.append(p)
                previous = p

            text = self.decode_tokens_to_str(decoded_prediction)

            if not return_hypotheses:
                hypothesis = text
            else:
                hypothesis = Hypothesis(
                    y_sequence=None,
                    score=-1.0,
                    text=text,
                    alignments=prediction,
                    length=predictions_len[ind]
                    if predictions_len is not None else 0,
                )

            hypotheses.append(hypothesis)
        return hypotheses
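
A minimal standalone sketch of the collapse rule used in the loop above: keep a token only if it differs from the previous one (or follows a blank) and is not itself the blank. The helper name `ctc_collapse` and the toy values are illustrative, not part of the NeMo API.

def ctc_collapse(token_ids, blank_id):
    # Same rule as the loop above: drop repeated tokens, then drop blanks.
    decoded = []
    previous = blank_id
    for p in token_ids:
        if (p != previous or previous == blank_id) and p != blank_id:
            decoded.append(p)
        previous = p
    return decoded

# With blank_id = 0: repeats are merged unless separated by a blank.
assert ctc_collapse([1, 1, 0, 1, 2, 2, 0, 0, 3], blank_id=0) == [1, 1, 2, 3]
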
Example 2
    def test_RNNTDecoder(self):
        vocab = list(range(10))
        vocab = [str(x) for x in vocab]
        vocab_size = len(vocab)

        pred_config = OmegaConf.create(
            {
                '_target_': 'nemo.collections.asr.modules.RNNTDecoder',
                'prednet': {'pred_hidden': 32, 'pred_rnn_layers': 1,},
                'vocab_size': vocab_size,
                'blank_as_pad': True,
            }
        )

        prednet = modules.RNNTDecoder.from_config_dict(pred_config)

        # num params
        pred_hidden = pred_config.prednet.pred_hidden
        embed = (vocab_size + 1) * pred_hidden  # embedding with blank
        rnn = (
            2 * 4 * (pred_hidden * pred_hidden + pred_hidden)
        )  # (ih + hh) * (ifco gates) * (indim * hiddendim + bias)
        assert prednet.num_weights == (embed + rnn)

        # State initialization
        x_ = torch.zeros(4, dtype=torch.float32)
        states = prednet.initialize_state(x_)

        for state_i in states:
            assert state_i.dtype == x_.dtype
            assert state_i.device == x_.device
            assert state_i.shape[1] == len(x_)

        # Blank hypotheses test
        blank = vocab_size
        hyp = Hypothesis(score=0.0, y_sequence=[blank])
        cache = {}
        pred, states, _ = prednet.score_hypothesis(hyp, cache)

        assert pred.shape == torch.Size([1, 1, pred_hidden])
        assert len(states) == 2
        for state_i in states:
            assert state_i.dtype == pred.dtype
            assert state_i.device == pred.device
            assert state_i.shape[1] == len(pred)

        # Blank stateless predict
        g, states = prednet.predict(y=None, state=None, add_sos=False, batch_size=1)

        assert g.shape == torch.Size([1, 1, pred_hidden])
        assert len(states) == 2
        for state_i in states:
            assert state_i.dtype == g.dtype
            assert state_i.device == g.device
            assert state_i.shape[1] == len(g)

        # Blank stateful predict
        g, states2 = prednet.predict(y=None, state=states, add_sos=False, batch_size=1)

        assert g.shape == torch.Size([1, 1, pred_hidden])
        assert len(states2) == 2
        for state_i, state_j in zip(states, states2):
            assert (state_i - state_j).square().sum().sqrt() > 0.0

        # Predict with token and state
        token = torch.full([1, 1], fill_value=0, dtype=torch.long)
        g, states = prednet.predict(y=token, state=states2, add_sos=False, batch_size=None)

        assert g.shape == torch.Size([1, 1, pred_hidden])
        assert len(states) == 2

        # Predict with blank token and no state
        token = torch.full([1, 1], fill_value=blank, dtype=torch.long)
        g, states = prednet.predict(y=token, state=None, add_sos=False, batch_size=None)

        assert g.shape == torch.Size([1, 1, pred_hidden])
        assert len(states) == 2
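
A self-contained sketch of the parameter count asserted above, built from plain torch.nn modules rather than the NeMo RNNTDecoder: an embedding over vocab_size + 1 tokens (blank included) plus a single-layer LSTM with pred_hidden units. Values mirror the test; the module choice is an assumption.

import torch.nn as nn

vocab_size, pred_hidden = 10, 32
embedding = nn.Embedding(vocab_size + 1, pred_hidden)  # +1 for the blank token
lstm = nn.LSTM(pred_hidden, pred_hidden, num_layers=1)

embed = (vocab_size + 1) * pred_hidden
rnn = 2 * 4 * (pred_hidden * pred_hidden + pred_hidden)  # (ih + hh) * 4 gates * (H*H weights + H bias)

total = sum(p.numel() for p in embedding.parameters()) + sum(p.numel() for p in lstm.parameters())
assert total == embed + rnn
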
Example 3
    def modified_adaptive_expansion_search(
            self, h: torch.Tensor,
            encoded_lengths: torch.Tensor) -> List[Hypothesis]:
        """
        Based on/modified from https://ieeexplore.ieee.org/document/9250505

        Args:
            h: Encoded speech features (1, T_max, D_enc)
            encoded_lengths: Length of the encoded feature sequence (scalar)

        Returns:
            nbest_hyps: N-best decoding results
        """
        if self.preserve_alignments:
            raise NotImplementedError(
                "`preseve_alignments` is not implemented for Alignment-length Synchronous Decoding."
            )

        h = h[0]  # [T, D]

        # prepare the batched beam states
        beam = min(self.beam_size, self.vocab_size)
        beam_state = self.decoder.initialize_state(
            torch.zeros(beam, device=h.device,
                        dtype=h.dtype))  # [L, B, H], [L, B, H] for LSTMS

        # Initialize first hypothesis for the beam (blank)
        init_tokens = [
            Hypothesis(
                y_sequence=[self.blank],
                score=0.0,
                dec_state=self.decoder.batch_select_state(beam_state, 0),
                timestep=[-1],
                length=0,
            )
        ]

        cache = {}

        # Decode a batch of beam states and scores
        beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score_hypothesis(
            init_tokens, cache, beam_state)
        state = self.decoder.batch_select_state(beam_state, 0)

        # TODO: Setup LM
        if self.language_model is not None:
            # beam_lm_states, beam_lm_scores = self.lm.buff_predict(
            #     None, beam_lm_tokens, 1
            # )
            # lm_state = select_lm_state(
            #     beam_lm_states, 0, self.lm_layers, self.is_wordlm
            # )
            # lm_scores = beam_lm_scores[0]
            raise NotImplementedError()
        else:
            lm_state = None
            lm_scores = None

        # Initialize first hypothesis for the beam (blank) for kept hypotheses
        kept_hyps = [
            Hypothesis(
                y_sequence=[self.blank],
                score=0.0,
                dec_state=state,
                dec_out=[beam_dec_out[0]],
                lm_state=lm_state,
                lm_scores=lm_scores,
            )
        ]

        for t in range(encoded_lengths):
            enc_out_t = h[t:t + 1].unsqueeze(0)  # [1, 1, D]

            # Perform prefix search to obtain the hypotheses
            hyps = self.prefix_search(
                sorted(kept_hyps,
                       key=lambda x: len(x.y_sequence),
                       reverse=True),
                enc_out_t,
                prefix_alpha=self.maes_prefix_alpha,
            )  # type: List[Hypothesis]
            kept_hyps = []

            # Prepare output tensor
            beam_enc_out = enc_out_t

            # List that contains the blank token emissions
            list_b = []

            # Repeat for number of mAES steps
            for n in range(self.maes_num_steps):
                # Pack the decoder logits for all current hypotheses
                beam_dec_out = torch.stack([h.dec_out[-1]
                                            for h in hyps])  # [H, 1, D]

                # Extract the log probabilities
                beam_logp = torch.log_softmax(
                    self.joint.joint(beam_enc_out, beam_dec_out) /
                    self.softmax_temperature,
                    dim=-1,
                )
                beam_logp = beam_logp[:, 0, 0, :]  # [B, V + 1]

                # Compute k expansions for all the current hypotheses
                k_expansions = select_k_expansions(hyps, beam_logp, beam,
                                                   self.maes_expansion_gamma,
                                                   self.maes_expansion_beta)

                # List that contains the hypotheses after prefix expansion
                list_exp = []
                for i, hyp in enumerate(hyps):  # For all hypotheses
                    for k, new_score in k_expansions[
                            i]:  # for all expansions within this hypothesis
                        new_hyp = Hypothesis(
                            y_sequence=hyp.y_sequence[:],
                            score=new_score,
                            dec_out=hyp.dec_out[:],
                            dec_state=hyp.dec_state,
                            lm_state=hyp.lm_state,
                            lm_scores=hyp.lm_scores,
                        )

                        # If the expansion was for blank
                        if k == self.blank:
                            list_b.append(new_hyp)
                        else:
                            # If the expansion was a token
                            new_hyp.y_sequence.append(int(k))

                            # TODO: Setup LM
                            if self.language_model is not None:
                                # new_hyp.score += self.lm_weight * float(
                                #     hyp.lm_scores[k]
                                # )
                                pass

                            list_exp.append(new_hyp)

                # If there were no token expansions in any of the hypotheses,
                # exit early
                if not list_exp:
                    kept_hyps = sorted(list_b,
                                       key=lambda x: x.score,
                                       reverse=True)[:beam]

                    break

                else:
                    # Initialize the beam states for the hypotheses in the expansion list
                    beam_state = self.decoder.batch_initialize_states(
                        beam_state,
                        [hyp.dec_state for hyp in list_exp],
                        # [hyp.y_sequence for hyp in list_exp],  # <look into when this is necessary>
                    )

                    # Decode a batch of beam states and scores
                    beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score_hypothesis(
                        list_exp,
                        cache,
                        beam_state,
                        # self.language_model is not None,
                    )

                    # TODO: Setup LM
                    if self.language_model is not None:
                        # beam_lm_states = create_lm_batch_states(
                        #     [hyp.lm_state for hyp in list_exp],
                        #     self.lm_layers,
                        #     self.is_wordlm,
                        # )
                        # beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                        #     beam_lm_states, beam_lm_tokens, len(list_exp)
                        # )
                        pass

                    # If this isn't the last mAES step
                    if n < (self.maes_num_steps - 1):
                        # For all expanded hypotheses
                        for i, hyp in enumerate(list_exp):
                            # Preserve the decoder logits for the current beam
                            hyp.dec_out.append(beam_dec_out[i])
                            hyp.dec_state = self.decoder.batch_select_state(
                                beam_state, i)

                            # TODO: Setup LM
                            if self.language_model is not None:
                                # hyp.lm_state = select_lm_state(
                                #     beam_lm_states, i, self.lm_layers, self.is_wordlm
                                # )
                                # hyp.lm_scores = beam_lm_scores[i]
                                pass

                        # Copy the expanded hypotheses
                        hyps = list_exp[:]
                    else:
                        # Extract the log probabilities
                        beam_logp = torch.log_softmax(
                            self.joint.joint(beam_enc_out, beam_dec_out) /
                            self.softmax_temperature,
                            dim=-1,
                        )
                        beam_logp = beam_logp[:, 0, 0, :]

                        # For all expansions, add the score for the blank label
                        for i, hyp in enumerate(list_exp):
                            hyp.score += float(beam_logp[i, self.blank])

                            # Preserve the decoder's output and state
                            hyp.dec_out.append(beam_dec_out[i])
                            hyp.dec_state = self.decoder.batch_select_state(
                                beam_state, i)

                            # TODO: Setup LM
                            if self.language_model is not None:
                                # hyp.lm_state = select_lm_state(
                                #     beam_lm_states, i, self.lm_layers, self.is_wordlm
                                # )
                                # hyp.lm_scores = beam_lm_scores[i]
                                pass

                        # Finally, update the kept hypotheses with the sorted top beam candidates
                        kept_hyps = sorted(list_b + list_exp,
                                           key=lambda x: x.score,
                                           reverse=True)[:beam]

        # Sort the hypotheses by best score
        return self.sort_nbest(kept_hyps)
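
An illustrative sketch of the adaptive-expansion pruning idea behind `select_k_expansions` (this is not the NeMo helper; the signature, names, and shapes here are assumptions): for each hypothesis, keep only candidate tokens whose log-probability lies within `gamma` of that hypothesis' best candidate, considering at most `beam + beta` candidates.

import torch

def select_k_expansions_sketch(hyp_scores, beam_logp, beam, gamma, beta):
    # hyp_scores: list of current hypothesis scores, length H.
    # beam_logp:  [H, V + 1] log-probabilities for the next token.
    # Returns, per hypothesis, a list of (token_id, new_score) expansions.
    expansions = []
    topk_logp, topk_idx = beam_logp.topk(beam + beta, dim=-1)  # sorted descending
    for i, score in enumerate(hyp_scores):
        best = float(topk_logp[i, 0])
        keep = [
            (int(k), float(score) + float(v))
            for v, k in zip(topk_logp[i], topk_idx[i])
            if float(v) >= best - gamma
        ]
        expansions.append(keep)
    return expansions

# Toy usage: two hypotheses, a vocabulary of 4 tokens plus blank.
logp = torch.log_softmax(torch.randn(2, 5), dim=-1)
print(select_k_expansions_sketch([0.0, -1.2], logp, beam=2, gamma=1.0, beta=1))
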
Example 4
    def align_length_sync_decoding(
            self, h: torch.Tensor,
            encoded_lengths: torch.Tensor) -> List[Hypothesis]:
        """Alignment-length synchronous beam search implementation.
        Based on https://ieeexplore.ieee.org/document/9053040

        Args:
            h: Encoded speech features (1, T_max, D_enc)
            encoded_lengths: Length of the encoded feature sequence (scalar)

        Returns:
            nbest_hyps: N-best decoding results
        """
        if self.preserve_alignments:
            raise NotImplementedError(
                "`preseve_alignments` is not implemented for Alignment-length Synchronous Decoding."
            )

        # Precompute some constants for blank position
        ids = list(range(self.vocab_size + 1))
        ids.remove(self.blank)

        # Used when blank token is first vs last token
        if self.blank == 0:
            index_incr = 1
        else:
            index_incr = 0

        # prepare the batched beam states
        beam = min(self.beam_size, self.vocab_size)

        h = h[0]  # [T, D]
        h_length = int(encoded_lengths)
        beam_state = self.decoder.initialize_state(
            torch.zeros(beam, device=h.device,
                        dtype=h.dtype))  # [L, B, H], [L, B, H] for LSTMS

        # compute u_max as either a specific static limit,
        # or a multiple of current `h_length` dynamically.
        if type(self.alsd_max_target_length) == float:
            u_max = int(self.alsd_max_target_length * h_length)
        else:
            u_max = int(self.alsd_max_target_length)

        # Initialize first hypothesis for the beam (blank)
        B = [
            Hypothesis(
                y_sequence=[self.blank],
                score=0.0,
                dec_state=self.decoder.batch_select_state(beam_state, 0),
                timestep=[-1],
                length=0,
            )
        ]

        final = []
        cache = {}

        # ALSD runs for T + U_max steps
        for i in range(h_length + u_max):
            # Update caches
            A = []
            B_ = []
            h_states = []

            # preserve the list of batch indices which are added into the list
            # and those which are removed from the list
            # This is necessary to perform state updates in the correct batch indices later
            batch_ids = list(range(
                len(B)))  # initialize as a list of all batch ids
            batch_removal_ids = []  # update with sample ids which are removed

            for bid, hyp in enumerate(B):
                u = len(hyp.y_sequence) - 1
                t = i - u + 1

                if t > (h_length - 1):
                    batch_removal_ids.append(bid)
                    continue

                B_.append(hyp)
                h_states.append((t, h[t]))

            if B_:
                # Compute the subset of batch ids which were *not* removed from the list above
                sub_batch_ids = None
                if len(B_) != beam:
                    sub_batch_ids = batch_ids
                    for id in batch_removal_ids:
                        # sub_batch_ids contains list of ids *that were not removed*
                        sub_batch_ids.remove(id)

                    # extract the states of the sub batch only.
                    beam_state_ = [
                        beam_state[state_id][:, sub_batch_ids, :]
                        for state_id in range(len(beam_state))
                    ]
                else:
                    # If entire batch was used (none were removed), simply take all the states
                    beam_state_ = beam_state

                # Decode a batch/sub-batch of beam states and scores
                beam_y, beam_state_, beam_lm_tokens = self.decoder.batch_score_hypothesis(
                    B_, cache, beam_state_)

                # If only a subset of batch ids were updated (some were removed)
                if sub_batch_ids is not None:
                    # For each state in the RNN (2 for LSTM)
                    for state_id in range(len(beam_state)):
                        # Update the current batch states with the sub-batch states (in the correct indices)
                        # These indices are specified by sub_batch_ids, the ids of samples which were updated.
                        beam_state[state_id][:,
                                             sub_batch_ids, :] = beam_state_[
                                                 state_id][...]
                else:
                    # If entire batch was updated, simply update all the states
                    beam_state = beam_state_

                # h_states is a list of (t, h[t]) pairs,
                # so h[1] in the comprehension below is h[t], of shape [D]
                # Simply stack all of the h[t] within the sub-batch/batch (T <= beam)
                h_enc = torch.stack([h[1] for h in h_states])  # [T=beam, D]
                h_enc = h_enc.unsqueeze(
                    1)  # [B=beam, T=1, D]; batch over the beams

                # Extract the log probabilities and the predicted tokens
                beam_logp = torch.log_softmax(self.joint.joint(h_enc, beam_y) /
                                              self.softmax_temperature,
                                              dim=-1)  # [B=beam, 1, 1, V + 1]
                beam_logp = beam_logp[:, 0, 0, :]  # [B=beam, V + 1]
                beam_topk = beam_logp[:, ids].topk(beam, dim=-1)

                for j, hyp in enumerate(B_):
                    # For all updated samples in the batch, add the blank token score.
                    # In this step we don't add a new token, we simply update the score
                    new_hyp = Hypothesis(
                        score=(hyp.score + float(beam_logp[j, self.blank])),
                        y_sequence=hyp.y_sequence[:],
                        dec_state=hyp.dec_state,
                        lm_state=hyp.lm_state,
                        timestep=hyp.timestep[:],
                        length=i,
                    )

                    # Add blank prediction to A
                    A.append(new_hyp)

                    # If the prediction "timestep" t has reached the length of the input sequence
                    # we can add it to the "finished" hypothesis list.
                    if h_states[j][0] == (h_length - 1):
                        final.append(new_hyp)

                    # Here, we carefully select the indices of the states that we want to preserve
                    # for the next token (non-blank) update.
                    if sub_batch_ids is not None:
                        h_states_idx = sub_batch_ids[j]
                    else:
                        h_states_idx = j

                    # for each current hypothesis j
                    # extract the top token score and top token id for the jth hypothesis
                    for logp, k in zip(beam_topk[0][j],
                                       beam_topk[1][j] + index_incr):
                        # create new hypothesis and store in A
                        # Note: This loop does *not* include the blank token!
                        new_hyp = Hypothesis(
                            score=(hyp.score + float(logp)),
                            y_sequence=(hyp.y_sequence[:] + [int(k)]),
                            dec_state=self.decoder.batch_select_state(
                                beam_state, h_states_idx),
                            lm_state=hyp.lm_state,
                            timestep=hyp.timestep[:] + [i],
                            length=i,
                        )

                        A.append(new_hyp)

                # Prune and recombine identical hypotheses.
                # This may cause the next beam to be smaller than the max beam size,
                # so larger beam sizes may be required for better decoding.
                B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
                B = self.recombine_hypotheses(B)

            # If B_ is an empty list, we may be able to exit early
            elif len(batch_ids) == len(batch_removal_ids):
                break

        if final:
            return self.sort_nbest(final)
        else:
            return B
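
A small sketch of the `u_max` rule above: a float limit is interpreted as a multiple of the encoded length, while an integer is a static cap. The helper name and values are illustrative.

def compute_u_max(alsd_max_target_length, h_length):
    # Float -> dynamic cap relative to the encoder length; int -> static cap.
    if isinstance(alsd_max_target_length, float):
        return int(alsd_max_target_length * h_length)
    return int(alsd_max_target_length)

assert compute_u_max(2.0, 100) == 200  # dynamic: 2x the encoded length
assert compute_u_max(50, 100) == 50    # static cap
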
Example 5
    def time_sync_decoding(self, h: torch.Tensor,
                           encoded_lengths: torch.Tensor) -> List[Hypothesis]:
        """Time synchronous beam search implementation.
        Based on https://ieeexplore.ieee.org/document/9053040

        Args:
            h: Encoded speech features (1, T_max, D_enc)
            encoded_lengths: Length of the encoded feature sequence (scalar)

        Returns:
            nbest_hyps: N-best decoding results
        """
        if self.preserve_alignments:
            raise NotImplementedError(
                "`preseve_alignments` is not implemented for Time-Synchronous Decoding."
            )

        # Precompute some constants for blank position
        ids = list(range(self.vocab_size + 1))
        ids.remove(self.blank)

        # Used when blank token is first vs last token
        if self.blank == 0:
            index_incr = 1
        else:
            index_incr = 0

        # prepare the batched beam states
        beam = min(self.beam_size, self.vocab_size)
        beam_state = self.decoder.initialize_state(
            torch.zeros(beam, device=h.device,
                        dtype=h.dtype))  # [L, B, H], [L, B, H] (for LSTMs)

        # Initialize first hypothesis for the beam (blank)
        B = [
            Hypothesis(
                y_sequence=[self.blank],
                score=0.0,
                dec_state=self.decoder.batch_select_state(beam_state, 0),
                timestep=[-1],
                length=0,
            )
        ]
        cache = {}

        for i in range(int(encoded_lengths)):
            hi = h[:, i:i + 1, :]

            # Update caches
            A = []
            C = B

            h_enc = hi

            # For a limited number of symmetric expansions per timestep "i"
            for v in range(self.tsd_max_symmetric_expansion_per_step):
                D = []

                # Decode a batch of beam states and scores
                beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score_hypothesis(
                    C, cache, beam_state)

                # Extract the log probabilities and the predicted tokens
                beam_logp = torch.log_softmax(self.joint.joint(h_enc, beam_y) /
                                              self.softmax_temperature,
                                              dim=-1)  # [B, 1, 1, V + 1]
                beam_logp = beam_logp[:, 0, 0, :]  # [B, V + 1]
                beam_topk = beam_logp[:, ids].topk(beam, dim=-1)

                seq_A = [h.y_sequence for h in A]

                for j, hyp in enumerate(C):
                    # create a new hypothesis in A
                    if hyp.y_sequence not in seq_A:
                        # If the sequence is not in seq_A, add the blank token score.
                        # In this step we don't add a new token, we simply update the score
                        A.append(
                            Hypothesis(
                                score=(hyp.score +
                                       float(beam_logp[j, self.blank])),
                                y_sequence=hyp.y_sequence[:],
                                dec_state=hyp.dec_state,
                                lm_state=hyp.lm_state,
                                timestep=hyp.timestep[:],
                                length=encoded_lengths,
                            ))
                    else:
                        # merge the existing blank hypothesis score with current score.
                        dict_pos = seq_A.index(hyp.y_sequence)

                        A[dict_pos].score = np.logaddexp(
                            A[dict_pos].score,
                            (hyp.score + float(beam_logp[j, self.blank])))

                # Generate token expansions only before the final symmetric expansion step
                if v < (self.tsd_max_symmetric_expansion_per_step - 1):
                    for j, hyp in enumerate(C):
                        # for each current hypothesis j
                        # extract the top token score and top token id for the jth hypothesis
                        for logp, k in zip(beam_topk[0][j],
                                           beam_topk[1][j] + index_incr):
                            # create new hypothesis and store in D
                            # Note: This loop does *not* include the blank token!
                            new_hyp = Hypothesis(
                                score=(hyp.score + float(logp)),
                                y_sequence=(hyp.y_sequence + [int(k)]),
                                dec_state=self.decoder.batch_select_state(
                                    beam_state, j),
                                lm_state=hyp.lm_state,
                                timestep=hyp.timestep[:] + [i],
                                length=encoded_lengths,
                            )

                            D.append(new_hyp)

                # Prune beam
                C = sorted(D, key=lambda x: x.score, reverse=True)[:beam]

            # Prune beam
            B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]

        return self.sort_nbest(B)
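
A tiny sketch of the log-space merge used in the recombination step above: when two beam entries share the same token sequence, their probabilities are summed, which in log-space is np.logaddexp. Values are illustrative.

import numpy as np

log_p1, log_p2 = np.log(0.2), np.log(0.1)
merged = np.logaddexp(log_p1, log_p2)   # log(0.2 + 0.1)
assert np.isclose(np.exp(merged), 0.3)
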
Example 6
    def default_beam_search(self, h: torch.Tensor,
                            encoded_lengths: torch.Tensor) -> List[Hypothesis]:
        """Beam search implementation.

        Args:
            h: Encoded speech features (1, T_max, D_enc)
            encoded_lengths: Length of the encoded feature sequence (scalar)

        Returns:
            nbest_hyps: N-best decoding results
        """
        # Initialize states
        beam = min(self.beam_size, self.vocab_size)
        beam_k = min(beam, (self.vocab_size - 1))
        blank_tensor = torch.tensor([self.blank],
                                    device=h.device,
                                    dtype=torch.long)

        # Precompute some constants for blank position
        ids = list(range(self.vocab_size + 1))
        ids.remove(self.blank)

        # Used when blank token is first vs last token
        if self.blank == 0:
            index_incr = 1
        else:
            index_incr = 0

        # Initialize zero vector states
        dec_state = self.decoder.initialize_state(h)

        # Initialize first hypothesis for the beam (blank)
        kept_hyps = [
            Hypothesis(score=0.0,
                       y_sequence=[self.blank],
                       dec_state=dec_state,
                       timestep=[-1],
                       length=0)
        ]
        cache = {}

        if self.preserve_alignments:
            kept_hyps[0].alignments = [[]]

        for i in range(int(encoded_lengths)):
            hi = h[:, i:i + 1, :]  # [1, 1, D]
            hyps = kept_hyps
            kept_hyps = []

            while True:
                max_hyp = max(hyps, key=lambda x: x.score)
                hyps.remove(max_hyp)

                # update decoder state and get next score
                y, state, lm_tokens = self.decoder.score_hypothesis(
                    max_hyp, cache)  # [1, 1, D]

                # get next token
                ytu = torch.log_softmax(self.joint.joint(hi, y) /
                                        self.softmax_temperature,
                                        dim=-1)  # [1, 1, 1, V + 1]
                ytu = ytu[0, 0, 0, :]  # [V + 1]

                # remove blank token before top k
                top_k = ytu[ids].topk(beam_k, dim=-1)

                # Two possible steps - blank token or non-blank token predicted
                ytu = (
                    torch.cat((top_k[0], ytu[self.blank].unsqueeze(0))),
                    torch.cat((top_k[1] + index_incr, blank_tensor)),
                )

                # for each possible step
                for logp, k in zip(*ytu):
                    # construct hypothesis for step
                    new_hyp = Hypothesis(
                        score=(max_hyp.score + float(logp)),
                        y_sequence=max_hyp.y_sequence[:],
                        dec_state=max_hyp.dec_state,
                        lm_state=max_hyp.lm_state,
                        timestep=max_hyp.timestep[:],
                        length=encoded_lengths,
                    )

                    if self.preserve_alignments:
                        new_hyp.alignments = copy.deepcopy(max_hyp.alignments)

                    # if the current token is blank, don't update the sequence, just store the current hypothesis
                    if k == self.blank:
                        kept_hyps.append(new_hyp)
                    else:
                        # if a non-blank token was predicted, update the state and sequence, then search for more hypotheses
                        new_hyp.dec_state = state
                        new_hyp.y_sequence.append(int(k))
                        new_hyp.timestep.append(i)

                        hyps.append(new_hyp)

                    if self.preserve_alignments:
                        if k == self.blank:
                            new_hyp.alignments[-1].append(self.blank)
                        else:
                            new_hyp.alignments[-1].append(
                                new_hyp.y_sequence[-1])

                # keep those hypotheses whose scores exceed the best remaining (unexpanded) hypothesis
                hyps_max = float(max(hyps, key=lambda x: x.score).score)
                kept_most_prob = sorted(
                    [hyp for hyp in kept_hyps if hyp.score > hyps_max],
                    key=lambda x: x.score,
                )

                # If enough hypotheses score higher than any remaining candidate,
                # stop the beam search.
                if len(kept_most_prob) >= beam:
                    if self.preserve_alignments:
                        # append a blank alignment buffer for the next timestep
                        for kept_h in kept_most_prob:
                            kept_h.alignments.append([])

                    kept_hyps = kept_most_prob
                    break

        # Remove trailing empty list of alignments
        if self.preserve_alignments:
            for h in kept_hyps:
                if len(h.alignments[-1]) == 0:
                    del h.alignments[-1]

        return self.sort_nbest(kept_hyps)
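
A standalone sketch of the candidate construction inside the loop above: the blank score is concatenated to the top-k non-blank scores so that each expansion step considers exactly beam_k + 1 choices. Shapes and values here are illustrative, independent of the NeMo classes.

import torch

vocab_plus_blank, blank, beam_k = 5, 0, 3
ytu = torch.log_softmax(torch.randn(vocab_plus_blank), dim=-1)  # [V + 1]

ids = [i for i in range(vocab_plus_blank) if i != blank]  # non-blank token ids
index_incr = 1 if blank == 0 else 0                       # remap positions when blank is id 0

top_k = ytu[ids].topk(beam_k, dim=-1)
scores = torch.cat((top_k[0], ytu[blank].unsqueeze(0)))             # [beam_k + 1]
tokens = torch.cat((top_k[1] + index_incr, torch.tensor([blank])))  # [beam_k + 1]
assert scores.shape[0] == tokens.shape[0] == beam_k + 1
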
Example 7
    def greedy_search(self, h: torch.Tensor,
                      encoded_lengths: torch.Tensor) -> List[Hypothesis]:
        """Greedy search implementation for transducer.
        Generic case when beam size = 1. Results might differ slightly due to implementation details
        as compared to `GreedyRNNTInfer` and `GreedyBatchRNNTInfer`.

        Args:
            h: Encoded speech features (1, T_max, D_enc)
            encoded_lengths: Length of the encoded feature sequence (scalar)

        Returns:
            hyp: 1-best decoding results
        """
        if self.preserve_alignments:
            # Alignments is a 2-dimensional jagged (ragged) list representing T x U
            alignments = [[]]
        else:
            alignments = None

        # Initialize zero state vectors
        dec_state = self.decoder.initialize_state(h)

        # Construct initial hypothesis
        hyp = Hypothesis(score=0.0,
                         y_sequence=[self.blank],
                         dec_state=dec_state,
                         timestep=[-1],
                         length=encoded_lengths)
        cache = {}

        # Initialize state and first token
        y, state, _ = self.decoder.score_hypothesis(hyp, cache)

        for i in range(int(encoded_lengths)):
            hi = h[:, i:i + 1, :]  # [1, 1, D]

            not_blank = True
            symbols_added = 0

            while not_blank:
                ytu = torch.log_softmax(self.joint.joint(hi, y) /
                                        self.softmax_temperature,
                                        dim=-1)  # [1, 1, 1, V + 1]
                ytu = ytu[0, 0, 0, :]  # [V + 1]

                # max() requires float
                if ytu.dtype != torch.float32:
                    ytu = ytu.float()

                logp, pred = torch.max(ytu, dim=-1)  # scalar log-prob and token id
                pred = pred.item()

                if self.preserve_alignments:
                    # record the predicted token id in the current timestep's alignment
                    alignments[-1].append(pred)

                if pred == self.blank:
                    not_blank = False

                    if self.preserve_alignments:
                        # append a blank alignment buffer for the next timestep
                        alignments.append([])
                else:
                    # Update state and current sequence
                    hyp.y_sequence.append(int(pred))
                    hyp.score += float(logp)
                    hyp.dec_state = state
                    hyp.timestep.append(i)

                    # Compute next state and token
                    y, state, _ = self.decoder.score_hypothesis(hyp, cache)
                symbols_added += 1

        # Remove trailing empty list of alignments
        if self.preserve_alignments:
            if len(alignments[-1]) == 0:
                del alignments[-1]

        # attach alignments to hypothesis
        hyp.alignments = alignments

        return [hyp]
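
A self-contained toy sketch of the emit-until-blank control flow above, using a stand-in joint function and a per-frame symbol cap; the helper, cap, and shapes are illustrative, not the NeMo API.

import torch

blank, vocab_size, T, max_symbols = 0, 4, 3, 5
torch.manual_seed(0)

def fake_joint(t, u):
    # Stand-in for joint(enc[t], dec(prefix of length u)): random log-probs.
    return torch.log_softmax(torch.randn(vocab_size + 1), dim=-1)

y_sequence, score = [], 0.0
for t in range(T):
    symbols_added = 0
    logp, pred = torch.max(fake_joint(t, len(y_sequence)), dim=-1)
    # Emit symbols at frame t until blank is predicted (or the cap is hit).
    while int(pred) != blank and symbols_added < max_symbols:
        y_sequence.append(int(pred))
        score += float(logp)
        symbols_added += 1
        logp, pred = torch.max(fake_joint(t, len(y_sequence)), dim=-1)

print(y_sequence, score)
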