Example #1
    def __call__(self, audio_signal: torch.Tensor, length: torch.Tensor):
        """Returns a list of hypotheses given an input batch of the encoder hidden embedding.
        Output token is generated auto-repressively.

        Args:
            audio_signal: A tensor of size (batch, features, timesteps).
            length: A tensor of ints representing the length of each input
                sequence in the batch.

        Returns:
            A packed list containing one Hypothesis per batch element.
        """
        with torch.no_grad():
            # Run optional preprocessing and the encoder forward pass
            encoder_output, encoded_lengths = self.run_encoder(
                audio_signal=audio_signal, length=length)
            encoder_output = encoder_output.transpose(1, 2)  # (B, T, D)
            logitlen = encoded_lengths

            inseq = encoder_output  # [B, T, D]
            hypotheses, timestamps = self._greedy_decode(inseq, logitlen)

            # Pack the hypotheses results
            packed_result = [
                rnnt_utils.Hypothesis(score=-1.0, y_sequence=[])
                for _ in range(len(hypotheses))
            ]
            for i in range(len(packed_result)):
                packed_result[i].y_sequence = torch.tensor(hypotheses[i],
                                                           dtype=torch.long)
                packed_result[i].length = timestamps[i]

            del hypotheses

        return packed_result
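
A minimal usage sketch for the entry point above. Hypothetical names: `decoding` stands for an instance of the surrounding inference class; shapes follow the docstring.

import torch

# `decoding` is a hypothetical instance of the class defining __call__ above
audio_signal = torch.randn(4, 80, 1200)        # (batch, features, timesteps)
length = torch.tensor([1200, 1100, 900, 640])  # per-sample input lengths

hyps = decoding(audio_signal=audio_signal, length=length)
for hyp in hyps:
    print(hyp.y_sequence, hyp.length)  # decoded token ids, emission timesteps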
Example #2
def pack_hypotheses(
    hypotheses: List[List[int]],
    timesteps: List[List[int]],
    logitlen: torch.Tensor,
    alignments: Optional[List[List[int]]] = None,
) -> List[rnnt_utils.Hypothesis]:
    logitlen_cpu = logitlen.to("cpu")
    return [
        rnnt_utils.Hypothesis(
            y_sequence=torch.tensor(sent, dtype=torch.long),
            score=-1.0,
            timestep=timestep,
            length=length,
            alignments=alignments[idx] if alignments is not None else None,
        )
        for idx, (sent, timestep, length) in enumerate(
            zip(hypotheses, timesteps, logitlen_cpu))
    ]
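
A toy call of pack_hypotheses; all values below are made up to illustrate the expected input shapes.

import torch

hypotheses = [[3, 7, 7, 12], [5, 9]]  # decoded token ids, one list per sample
timesteps = [[0, 1, 1, 3], [0, 2]]    # emission time indices, one list per sample
logitlen = torch.tensor([4, 3])       # encoder output lengths (moved to CPU inside)

packed = pack_hypotheses(hypotheses, timesteps, logitlen)
print(packed[0].y_sequence)  # tensor([ 3,  7,  7, 12])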
Example #3
    def _greedy_decode_masked(
        self,
        x: torch.Tensor,
        out_len: torch.Tensor,
        device: torch.device,
        partial_hypotheses: Optional[List[rnnt_utils.Hypothesis]] = None,
    ):
        if partial_hypotheses is not None:
            raise NotImplementedError(
                "`partial_hypotheses` is not supported")

        # x: [B, T, D]
        # out_len: [B]
        # device: torch.device

        # Initialize state
        batchsize = x.shape[0]
        hypotheses = [
            rnnt_utils.Hypothesis(score=0.0,
                                  y_sequence=[],
                                  timestep=[],
                                  dec_state=None) for _ in range(batchsize)
        ]

        # Initialize Hidden state matrix (shared by entire batch)
        hidden = None

        # If alignments need to be preserved, register a dangling list to hold the values
        if self.preserve_alignments:
            # alignments is a 3-dimensional dangling list representing B x T x U
            for hyp in hypotheses:
                hyp.alignments = [[]]

        # Last Label buffer + Last Label without blank buffer
        # batch level equivalent of the last_label
        last_label = torch.full([batchsize, 1],
                                fill_value=self._blank_index,
                                dtype=torch.long,
                                device=device)
        last_label_without_blank = last_label.clone()

        # Mask buffers
        blank_mask = torch.full([batchsize],
                                fill_value=0,
                                dtype=torch.bool,
                                device=device)

        # Get max sequence length
        max_out_len = out_len.max()
        for time_idx in range(max_out_len):
            f = x.narrow(dim=1, start=time_idx, length=1)  # [B, 1, D]

            # Prepare t timestamp batch variables
            not_blank = True
            symbols_added = 0

            # Reset blank mask with the time mask
            # Batch: [B, T, D], but sample Bi may have seq len < max(seq_lens_in_batch)
            # Forcibly mask with "blank" for all samples where the current
            # time step is past the sample's seq_len (time_idx >= out_len)
            blank_mask = time_idx >= out_len

            # Start inner loop
            while not_blank and (self.max_symbols is None
                                 or symbols_added < self.max_symbols):
                # Batch prediction and joint network steps
                # If very first prediction step, submit SOS tag (blank) to pred_step.
                # This feeds a zero tensor as input to AbstractRNNTDecoder to prime the state
                if time_idx == 0 and symbols_added == 0 and hidden is None:
                    g, hidden_prime = self._pred_step(self._SOS,
                                                      hidden,
                                                      batch_size=batchsize)
                else:
                    # Set a dummy label for the blank value
                    # This value will be overwritten by "blank" again at the last-label update below
                    # This is done because the prediction network's vocabulary does not contain the RNNT "blank" token
                    last_label_without_blank_mask = last_label == self._blank_index
                    last_label_without_blank[
                        last_label_without_blank_mask] = 0  # temp change of label
                    last_label_without_blank[
                        ~last_label_without_blank_mask] = last_label[
                            ~last_label_without_blank_mask]

                    # Perform batch step prediction of decoder, getting new states and scores ("g")
                    g, hidden_prime = self._pred_step(last_label_without_blank,
                                                      hidden,
                                                      batch_size=batchsize)

                # Batched joint step - Output = [B, V + 1]
                logp = self._joint_step(f, g, log_normalize=None)[:, 0, 0, :]

                if logp.dtype != torch.float32:
                    logp = logp.float()

                # Get index k, of max prob for batch
                v, k = logp.max(1)
                del g

                # Update blank mask with current predicted blanks
                # This is accumulating blanks over all time steps T and all target steps min(max_symbols, U)
                k_is_blank = k == self._blank_index
                blank_mask.bitwise_or_(k_is_blank)

                # If preserving alignments, check if sequence length of sample has been reached
                # before adding alignment
                if self.preserve_alignments:
                    # Insert ids into last timestep per sample
                    logp_vals = logp.to('cpu').max(1)[1]
                    for batch_idx in range(batchsize):
                        if time_idx < out_len[batch_idx]:
                            hypotheses[batch_idx].alignments[-1].append(
                                logp_vals[batch_idx])
                    del logp_vals
                del logp

                # If all samples predict / have predicted prior blanks, exit loop early
                # This is the batch-level equivalent of a single sample predicting blank
                if blank_mask.all():
                    not_blank = False

                    # If preserving alignments, close out the Uj alignments at the
                    # current timestep Ti and open a fresh buffer for Ti+1 per sample.
                    # This should only be done iff the current time index <= sample-level AM length.
                    # Otherwise ignore and move to the next sample / next timestep.
                    if self.preserve_alignments:

                        # per sample, append a blank buffer for the next timestep
                        for batch_idx in range(batchsize):

                            # this checks if current timestep <= sample-level AM length
                            # If current timestep > sample-level AM length, no alignments will be added
                            # Therefore the list of Uj alignments is empty here.
                            if len(hypotheses[batch_idx].alignments[-1]) > 0:
                                hypotheses[batch_idx].alignments.append(
                                    [])  # blank buffer for next timestep
                else:
                    # Collect batch indices where blanks occurred now/past
                    blank_indices = (blank_mask == 1).nonzero(as_tuple=False)

                    # Recover prior state for all samples which predicted blank now/past
                    if hidden is not None:
                        # LSTM has 2 states
                        hidden_prime = self.decoder.batch_copy_states(
                            hidden_prime, hidden, blank_indices)

                    elif len(blank_indices) > 0 and hidden is None:
                        # Reset state if there were some blank and other non-blank predictions in batch
                        # Original state is filled with zeros so we just multiply
                        # LSTM has 2 states
                        hidden_prime = self.decoder.batch_copy_states(
                            hidden_prime, None, blank_indices, value=0.0)

                    # Recover prior predicted label for all samples which predicted blank now/past
                    k[blank_indices] = last_label[blank_indices, 0]

                    # Update new label and hidden state for next iteration
                    last_label = k.view(-1, 1)
                    hidden = hidden_prime

                    # Update predicted labels, accounting for time mask
                    # If blank was predicted even once, now or in the past,
                    # Force the current predicted label to also be blank
                    # This ensures that blanks propagate across all timesteps
                    # once they have occurred (normally the stopping condition of the sample-level loop).
                    for kidx, ki in enumerate(k):
                        if blank_mask[kidx] == 0:
                            hypotheses[kidx].y_sequence.append(ki)
                            hypotheses[kidx].timestep.append(time_idx)
                            hypotheses[kidx].score += float(v[kidx])

                symbols_added += 1

        # Remove trailing empty list of alignments at T_{am-len} x Uj
        if self.preserve_alignments:
            for batch_idx in range(batchsize):
                if len(hypotheses[batch_idx].alignments[-1]) == 0:
                    del hypotheses[batch_idx].alignments[-1]

        # Preserve states
        for batch_idx in range(batchsize):
            hypotheses[batch_idx].dec_state = self.decoder.batch_select_state(
                hidden, batch_idx)

        return hypotheses
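
The batching trick above is the accumulating blank mask: once a sample emits blank (or its sequence has ended), it stays frozen for the rest of the inner loop while the other samples keep decoding. A model-free toy illustration of that masking:

import torch

blank_index = 4
out_len = torch.tensor([3, 2, 1])  # per-sample encoder lengths
time_idx = 1

# Samples whose sequence has already ended are forced to "blank"
blank_mask = time_idx >= out_len            # tensor([False, False,  True])

# Suppose the joint network picked these labels at this inner step
k = torch.tensor([2, 4, 0])
blank_mask.bitwise_or_(k == blank_index)    # tensor([False,  True,  True])

# Only still-active samples extend their hypotheses, as in the loop above
for idx in range(len(k)):
    if not blank_mask[idx]:
        print(f"sample {idx} appends label {k[idx].item()}")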
Example #4
    def _greedy_decode(
            self,
            x: torch.Tensor,
            out_len: torch.Tensor,
            partial_hypotheses: Optional[rnnt_utils.Hypothesis] = None):
        # x: [T, 1, D]
        # out_len: scalar sequence length (number of encoder timesteps)

        # Initialize blank state and empty label set in Hypothesis
        hypothesis = rnnt_utils.Hypothesis(score=0.0,
                                           y_sequence=[],
                                           dec_state=None,
                                           timestep=[])

        if partial_hypotheses is not None:
            if len(partial_hypotheses.y_sequence) > 0:
                hypothesis.y_sequence.append(
                    partial_hypotheses.y_sequence[-1].cpu().numpy())
                hypothesis.dec_state = self.decoder.batch_concat_states(
                    [partial_hypotheses.dec_state])
                hypothesis.dec_state = _states_to_device(
                    hypothesis.dec_state, x.device)

        if self.preserve_alignments:
            # Alignments is a 2-dimensional dangling list representing T x U
            hypothesis.alignments = [[]]

        # For timestep t in X_t
        for time_idx in range(out_len):
            # Extract encoder embedding at timestep t
            # f = x[time_idx, :, :].unsqueeze(0)  # [1, 1, D]
            f = x.narrow(dim=0, start=time_idx, length=1)

            # Setup exit flags and counter
            not_blank = True
            symbols_added = 0

            # While blank is not predicted and we don't run out of max symbols per timestep
            while not_blank and (self.max_symbols is None
                                 or symbols_added < self.max_symbols):
                # In the first timestep, we initialize the network with RNNT Blank
                # In later timesteps, we provide previous predicted label as input.
                last_label = (self._SOS if (hypothesis.y_sequence == []
                                            and hypothesis.dec_state is None)
                              else hypothesis.y_sequence[-1])

                # Perform prediction network and joint network steps.
                g, hidden_prime = self._pred_step(last_label,
                                                  hypothesis.dec_state)
                logp = self._joint_step(f, g, log_normalize=None)[0, 0, 0, :]

                del g

                # torch.max(0) op doesn't exist for FP16.
                if logp.dtype != torch.float32:
                    logp = logp.float()

                # get index k, of max prob
                v, k = logp.max(0)
                k = k.item()  # k is the label at timestep t_s in the inner loop, s >= 0.

                if self.preserve_alignments:
                    # insert predicted label id into the current timestep's alignment
                    hypothesis.alignments[-1].append(k)

                del logp

                # If blank token is predicted, exit inner loop, move onto next timestep t
                if k == self._blank_index:
                    not_blank = False

                    if self.preserve_alignments:
                        # open a blank alignment buffer for the next timestep
                        hypothesis.alignments.append([])
                else:
                    # Append token to label set, update RNN state.
                    hypothesis.y_sequence.append(k)
                    hypothesis.score += float(v)
                    hypothesis.timestep.append(time_idx)
                    hypothesis.dec_state = hidden_prime

                # Increment token counter.
                symbols_added += 1

        # Remove trailing empty list of Alignments
        if self.preserve_alignments:
            if len(hypothesis.alignments[-1]) == 0:
                del hypothesis.alignments[-1]

        # Unpack the hidden states
        hypothesis.dec_state = self.decoder.batch_select_state(
            hypothesis.dec_state, 0)

        # Remove the original input label if partial hypothesis was provided
        if partial_hypotheses is not None:
            hypothesis.y_sequence = hypothesis.y_sequence[1:]

        return hypothesis
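
A per-sample usage sketch for _greedy_decode. Hypothetical: `infer` is an instance of the surrounding class, with trained prediction and joint networks behind _pred_step and _joint_step.

import torch

x = torch.randn(500, 1, 640)  # [T, 1, D] encoder output for a single sample
out_len = torch.tensor(500)   # 0-dim int tensor; range() accepts it via __index__

hyp = infer._greedy_decode(x, out_len)  # `infer` is hypothetical
print(hyp.y_sequence, hyp.timestep, hyp.score)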