Example #1
    def forward(self, targets, target_length, states=None):
        # y: (B, U)
        y = rnn.label_collate(targets)

        # State maintenance is unnecessary during the training forward call;
        # to obtain the decoder state, use the .predict() method.
        g, _ = self.predict(y, state=states, add_sos=True)  # (B, U, D)
        g = g.transpose(1, 2)  # (B, D, U)

        return g, target_length
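
A minimal, self-contained sketch of the shape flow above, with a plain embedding plus LSTM standing in for the prediction network (the stand-in and all sizes are assumptions, and the add_sos handling is omitted):

    import torch
    import torch.nn as nn

    B, U, V, D = 4, 7, 28, 320               # batch, target length, vocab, hidden size
    embed = nn.Embedding(V, D)                # hypothetical stand-in for predict()
    lstm = nn.LSTM(D, D, batch_first=True)

    targets = torch.randint(0, V, (B, U))     # y: (B, U)
    g, _ = lstm(embed(targets))               # (B, U, D)
    g = g.transpose(1, 2)                     # (B, D, U), channels-first for the joint
    assert g.shape == (B, D, U)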
Example #2
    def _pred_step(
        self,
        label: Union[torch.Tensor, int],
        hidden: Optional[torch.Tensor],
        add_sos: bool = False,
        batch_size: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Common prediction step based on the AbstractRNNTDecoder implementation.

        Args:
            label: (int/torch.Tensor): Label or "Start-of-Signal" token.
            hidden: (Optional torch.Tensor): RNN State vector
            add_sos (bool): Whether to add a zero vector at the beginning as the "start of sentence" token.
            batch_size: Batch size of the output tensor.

        Returns:
            g: (B, U, H) if add_sos is false, else (B, U + 1, H)
            hid: (h, c) where h is the final sequence hidden state and c is
                the final cell state:
                    h (tensor), shape (L, B, H)
                    c (tensor), shape (L, B, H)
        """
        if isinstance(label, torch.Tensor):
            # label: [batch, 1]
            if label.dtype != torch.long:
                label = label.long()
        else:
            # Label is an integer
            if label == self._SOS:
                return self.decoder.predict(None,
                                            hidden,
                                            add_sos=add_sos,
                                            batch_size=batch_size)

            label = label_collate([[label]])

        # output: [B, 1, K]
        return self.decoder.predict(label,
                                    hidden,
                                    add_sos=add_sos,
                                    batch_size=batch_size)
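
For the scalar-label branch above, label_collate([[label]]) is expected to yield a (1, 1) long tensor. A minimal sketch of that behavior, inferred from the shapes the surrounding code relies on (not the actual implementation):

    import torch

    def label_collate_sketch(labels):
        # Batch a nested list of int labels into a (B, U) long tensor.
        return torch.tensor(labels, dtype=torch.long)

    batched = label_collate_sketch([[17]])    # 17 is an arbitrary non-SOS label
    assert batched.shape == (1, 1) and batched.dtype == torch.long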
Example #3
    def _greedy_decode(
            self,
            x: torch.Tensor,
            out_len: torch.Tensor,
            partial_hypotheses: Optional[rnnt_utils.Hypothesis] = None):
        # x: [T, 1, D]
        # out_len: [seq_len]

        # Initialize blank state and empty label set in Hypothesis
        hypothesis = rnnt_utils.Hypothesis(score=0.0,
                                           y_sequence=[],
                                           dec_state=None,
                                           timestep=[],
                                           last_token=None)

        if partial_hypotheses is not None:
            hypothesis.last_token = partial_hypotheses.last_token
            if partial_hypotheses.dec_state is not None:
                hypothesis.dec_state = self.decoder.batch_concat_states(
                    [partial_hypotheses.dec_state])
                hypothesis.dec_state = _states_to_device(
                    hypothesis.dec_state, x.device)

        if self.preserve_alignments:
            # Alignments is a two-dimensional ragged list representing T x U
            hypothesis.alignments = [[]]

        # For timestep t in X_t
        for time_idx in range(out_len):
            # Extract encoder embedding at timestep t
            # f = x[time_idx, :, :].unsqueeze(0)  # [1, 1, D]
            f = x.narrow(dim=0, start=time_idx, length=1)

            # Setup exit flags and counter
            not_blank = True
            symbols_added = 0

            # While blank is not predicted and we haven't exceeded max symbols per timestep
            while not_blank and (self.max_symbols is None
                                 or symbols_added < self.max_symbols):
                # In the first timestep, we initialize the network with RNNT Blank
                # In later timesteps, we provide previous predicted label as input.
                if hypothesis.last_token is None and hypothesis.dec_state is None:
                    last_label = self._SOS
                else:
                    last_label = label_collate([[hypothesis.last_token]])

                # Perform prediction network and joint network steps.
                g, hidden_prime = self._pred_step(last_label,
                                                  hypothesis.dec_state)
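                # The joint output is [B, T, U, V + 1]; with B = T = U = 1 here,
                # [0, 0, 0, :] selects the logits over the vocabulary plus blank.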
                logp = self._joint_step(f, g, log_normalize=None)[0, 0, 0, :]

                del g

                # torch.max(0) op doesn't exist for FP16, so cast up to float32.
                if logp.dtype != torch.float32:
                    logp = logp.float()

                # get index k, of max prob
                v, k = logp.max(0)
                k = k.item()  # k is the label at timestep t_s in the inner loop, s >= 0.

                if self.preserve_alignments:
                    # record the predicted label index in the current timestep's alignment
                    hypothesis.alignments[-1].append(k)

                del logp

                # If blank token is predicted, exit inner loop, move onto next timestep t
                if k == self._blank_index:
                    not_blank = False

                    if self.preserve_alignments:
                        # open an empty alignment buffer for the next timestep
                        hypothesis.alignments.append([])
                else:
                    # Append token to label set, update RNN state.
                    hypothesis.y_sequence.append(k)
                    hypothesis.score += float(v)
                    hypothesis.timestep.append(time_idx)
                    hypothesis.dec_state = hidden_prime
                    hypothesis.last_token = k

                # Increment token counter.
                symbols_added += 1

        # Remove the trailing empty alignment buffer
        if self.preserve_alignments:
            if len(hypothesis.alignments[-1]) == 0:
                del hypothesis.alignments[-1]

        # Unpack the hidden states
        hypothesis.dec_state = self.decoder.batch_select_state(
            hypothesis.dec_state, 0)

        return hypothesis
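
The blank-vs-emit control flow of the inner loop can be illustrated in isolation. The toy below uses random logits in place of the prediction and joint networks (every name and size here is hypothetical; only the loop structure mirrors the code above):

    import torch

    BLANK, VOCAB, T, MAX_SYMBOLS = 0, 5, 4, 3
    torch.manual_seed(0)

    labels, timesteps = [], []
    for t in range(T):                         # one encoder frame per outer step
        symbols_added = 0
        while symbols_added < MAX_SYMBOLS:
            logp = torch.randn(VOCAB + 1)      # stand-in for joint(f_t, g_u)
            k = int(logp.argmax())
            if k == BLANK:                     # blank predicted: move to next frame
                break
            labels.append(k)                   # non-blank: emit and stay on frame t
            timesteps.append(t)
            symbols_added += 1

    print(labels, timesteps)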