Example #1
    def forward(
        self,
        src_tokens,
        src_lengths,
        max_target_position,
        return_all_hiddens: bool = False,
        token_embeddings: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`
            max_target_position (int): largest target position; used to build
                the positional encoding table
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            token_embeddings (torch.Tensor, optional): precomputed embeddings
                default `None` will recompute embeddings

        Returns:
            namedtuple:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
            Tensor: position attention weights collected at the positional
                layers, of shape `(num_positional_layers, batch, src_len,
                max_target_position + 1)`, returned as the second element of a
                tuple together with the namedtuple.
        """
        x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)
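        # x: B x T x C embedded source; encoder_embedding: B x T x C token embedding lookup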
        # assume that we obtain an embedding for the target positions called target_pos (B x T x C)
        # max_target_position = torch.max(target_pos)

        # create the position table based on the greatest target position
        # (for now this is just the source sentence length)
        num_sentences, src_len, d = x.size()
        pos_table = self.constant_positional_encoding[:max_target_position + 1]
        # pos_table: B x T x C
        pos_table = pos_table.repeat(num_sentences, 1).view(num_sentences, max_target_position + 1, -1)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)

        encoder_states = [] if return_all_hiddens else None

        # get the layers at which position attention should be applied
        position_layer_list = set(self.positional_layers)
        num_position_layers = len(position_layer_list)
        probability = torch.empty(
            num_position_layers, num_sentences, src_len, max_target_position + 1,
            device=x.device,
        )
        # encoder layers, counting from 1; collect one attention map per positional layer
        pos_idx = 0
        for count, layer in enumerate(self.layers, 1):
            if count in position_layer_list:
                reordered_position, pos_attention = position_attention(x, pos_table, 1)
                probability[pos_idx] = pos_attention
                pos_idx += 1
                x = x + reordered_position
            x = layer(x, encoder_padding_mask)
            if return_all_hiddens:
                assert encoder_states is not None
                encoder_states.append(x)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        return EncoderOut(
            encoder_out=x,  # T x B x C
            encoder_padding_mask=encoder_padding_mask,  # B x T
            encoder_embedding=encoder_embedding,  # B x T x C
            encoder_states=encoder_states,  # List[T x B x C]
            src_tokens=None,
            src_lengths=None,
        ), probability
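
The `position_attention` helper and `self.constant_positional_encoding` are not shown above. A minimal sketch of what such a helper might compute, assuming plain scaled dot-product attention from the encoder states onto the position table (the body and the temperature argument are assumptions, not this example's actual implementation):

    import torch
    import torch.nn.functional as F

    def position_attention(x, pos_table, temperature=1.0):
        # x: T x B x C encoder states; pos_table: B x T_pos x C position encodings
        queries = x.transpose(0, 1)                              # B x T x C
        scores = torch.bmm(queries, pos_table.transpose(1, 2))   # B x T x T_pos
        scores = scores / (temperature * queries.size(-1) ** 0.5)
        attn = F.softmax(scores, dim=-1)                         # B x T x T_pos
        reordered = torch.bmm(attn, pos_table).transpose(0, 1)   # T x B x C, matches x
        return reordered, attn

The returned attention map is what the forward above stores per positional layer.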
Example #2
    def forward(
        self,
        src_tokens,
        src_lengths,
        cls_input: Optional[Tensor] = None,
        return_all_hiddens: bool = False,
    ):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).

        Returns:
            namedtuple:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
        """
        if self.layer_wise_attention:
            return_all_hiddens = True

        x, encoder_embedding = self.forward_embedding(src_tokens)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)

        encoder_states = [] if return_all_hiddens else None

        # encoder layers
        for layer in self.layers:
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = torch.empty(1).uniform_()
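            # run the layer unless we are training and the sampled value falls at or below encoder_layerdrop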
            if not self.training or (dropout_probability > self.encoder_layerdrop):
                x = layer(x, encoder_padding_mask)
                if return_all_hiddens:
                    assert encoder_states is not None
                    encoder_states.append(x)

        if self.layer_norm is not None:
            x = self.layer_norm(x)
            if return_all_hiddens:
                encoder_states[-1] = x

        return EncoderOut(
            encoder_out=x,  # T x B x C
            encoder_padding_mask=encoder_padding_mask,  # B x T
            encoder_embedding=encoder_embedding,  # B x T x C
            encoder_states=encoder_states,  # List[T x B x C]
        )
Example #3
    def forward(
        self,
        src_tokens,
        src_lengths,
        return_all_hiddens: bool = False,
    ):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).

        Returns:
            namedtuple:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
        """
        if self.conv_layers_before is not None:
            x, src_lengths, encoder_padding_mask = self.conv_layers_before(
                src_tokens, src_lengths)
        else:
            x, encoder_padding_mask = src_tokens, \
                ~speech_utils.sequence_mask(src_lengths, src_tokens.size(1))

        x = self.dropout_module(x)
        if self.fc0 is not None:
            x = self.fc0(x)
            if self.embed_positions is not None:
                # 0s in `~encoder_padding_mask` are used as pad_idx for positional embeddings
                x = x + self.embed_positions((~encoder_padding_mask).int())
            if self.layernorm_embedding is not None:
                x = self.layernorm_embedding(x)
            x = self.dropout_module(x)
        elif self.embed_positions is not None:
            # 0s in `~encoder_padding_mask` are used as pad_idx for positional embeddings
            x = x + self.embed_positions((~encoder_padding_mask).int())
            if self.layernorm_embedding is not None:
                x = self.layernorm_embedding(x)

        if not encoder_padding_mask.any():
            encoder_padding_mask = None

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        attn_mask = self.get_attn_mask(src_lengths)

        encoder_states = [] if return_all_hiddens else None

        # encoder layers
        for layer in self.layers:
            x = layer(x, encoder_padding_mask, attn_mask=attn_mask)
            if return_all_hiddens:
                assert encoder_states is not None
                encoder_states.append(x)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        return EncoderOut(
            encoder_out=x,  # T x B x C
            encoder_padding_mask=encoder_padding_mask,  # B x T
            encoder_embedding=None,
            encoder_states=encoder_states,  # List[T x B x C]
            src_tokens=None,
            src_lengths=None,
        )
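
`speech_utils.sequence_mask` is external to this snippet; judging from how its negation is used as the padding mask above, a typical implementation might look like the following sketch (an assumption, not the library's actual code):

    import torch

    def sequence_mask(lengths, max_len=None):
        # True at valid (non-padded) positions, False at padding: B x T
        if max_len is None:
            max_len = int(lengths.max())
        positions = torch.arange(max_len, device=lengths.device)
        return positions.unsqueeze(0) < lengths.unsqueeze(1)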
Example #4
    def forward(self,
                src_tokens,
                src_lengths,
                return_all_hiddens: bool = False):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).

        Returns:
            namedtuple:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
        """
        x, encoder_embedding = self.forward_embedding(src_tokens)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)

        encoder_states = [] if return_all_hiddens else None
        self_attn_at_list = []
        # encoder layers
        for layer in self.layers:
            x, self_attn_at = layer(x, encoder_padding_mask)
            if return_all_hiddens:
                assert encoder_states is not None
                encoder_states.append(x)
            self_attn_at_list.append(self_attn_at.transpose(1, 0).contiguous())

        self_attn_at_tensor = None
        if len(self_attn_at_list) > 0:
            self_attn_at_tensor = torch.stack(
                self_attn_at_list,
                dim=0)  # (layers, batch, heads, tgt_len, src_len)
            self_attn_at_tensor = self_attn_at_tensor.transpose(
                0, 1).contiguous()
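            # self_attn_at_tensor: (batch, layers, heads, tgt_len, src_len)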

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        return EncoderOut(
            encoder_out=x,  # T x B x C
            encoder_padding_mask=encoder_padding_mask,  # B x T
            encoder_embedding=encoder_embedding,  # B x T x C
            encoder_states=encoder_states,  # List[T x B x C]
            src_tokens=None,
            src_lengths=None,
            self_attn_at_tensor=self_attn_at_tensor,
        )
Example #5
    def forward(
        self,
        src_tokens: Tensor,
        src_lengths: Tensor,
        enforce_sorted: bool = True,
        **unused,
    ):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of
                shape `(batch, src_len)`
            src_lengths (LongTensor): lengths of each source sentence of
                shape `(batch)`
            enforce_sorted (bool, optional): if True, `src_tokens` is
                expected to contain sequences sorted by length in a
                decreasing order. If False, this condition is not
                required. Default: True.
        """
        if self.left_pad:
            # nn.utils.rnn.pack_padded_sequence requires right-padding;
            # convert left-padding to right-padding
            src_tokens = speech_utils.convert_padding_direction(
                src_tokens,
                src_lengths,
                left_to_right=True,
            )

        if self.conv_layers_before is not None:
            x, src_lengths, padding_mask = self.conv_layers_before(src_tokens, src_lengths)
        else:
            x, padding_mask = src_tokens, \
                ~speech_utils.sequence_mask(src_lengths, src_tokens.size(1))

        bsz, seqlen = x.size(0), x.size(1)

        x = F.dropout(x, p=self.dropout_in, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        state_size = (2 if self.bidirectional else 1), bsz, self.hidden_size  # (num_directions, bsz, hidden_size)
        h0, c0 = x.new_zeros(*state_size), x.new_zeros(*state_size)

        for i in range(len(self.lstm)):
            if self.residual and i > 0:  # residual connection starts from the 2nd layer
                prev_x = x
            # pack embedded source tokens into a PackedSequence
            packed_x = nn.utils.rnn.pack_padded_sequence(
                x, src_lengths.data, enforce_sorted=enforce_sorted
            )

            # apply LSTM
            packed_outs, (_, _) = self.lstm[i](packed_x, (h0, c0))

            # unpack outputs and apply dropout
            x, _ = nn.utils.rnn.pad_packed_sequence(packed_outs, padding_value=self.padding_value*1.0)
            if i < len(self.lstm) - 1:  # not applying dropout for the last layer
                x = F.dropout(x, p=self.dropout_out, training=self.training)
            x = x + prev_x if self.residual and i > 0 else x
        assert list(x.size()) == [seqlen, bsz, self.output_units]

        encoder_padding_mask = padding_mask.t()

        return EncoderOut(
            encoder_out=x,  # T x B x C
            encoder_padding_mask=encoder_padding_mask if encoder_padding_mask.any() else None,  # T x B
            encoder_embedding=None,
            encoder_states=None,
            src_tokens=None,
            src_lengths=src_lengths,  # B
        )
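
When `enforce_sorted` is True, `nn.utils.rnn.pack_padded_sequence` requires the batch to be ordered by decreasing length. A minimal sketch of how a caller might prepare such a batch (the `encoder` instance and the feature tensors below are assumptions):

    import torch

    src_tokens = torch.randn(3, 50, 80)        # B x T x feat, e.g. speech frames
    src_lengths = torch.tensor([42, 50, 37])

    # sort by decreasing length, as enforce_sorted=True expects
    src_lengths, order = src_lengths.sort(descending=True)
    src_tokens = src_tokens[order]
    encoder_out = encoder.forward(src_tokens, src_lengths, enforce_sorted=True)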
Example #6
    def forward(self,
                src_video,
                src_lengths=None,
                return_all_hiddens: bool = False,
                **kwargs):
        """

        :param src_video: [batch, num_frames, channels, width, height]
        :param src_lengths:
        :param kwargs:
        :return:
        """
        # cnn module
        bs, num_fm, c, h, w = src_video.size()
        src = src_video.view(bs * num_fm, c, h, w)
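        # frames are folded into the batch dimension so the 2-D CNN processes each frame independently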
        spatiol_feature = self.spatio_enc(
            src).squeeze()  # [bs * num_fm, embed_dim]

        if self.args.cnn_normalize_after:
            spatiol_feature = self.batchnorm(spatiol_feature)
            spatiol_feature = self.relu(spatiol_feature)
        spatiol_feature = spatiol_feature.view(
            bs, num_fm, spatiol_feature.size(-1))  # [bs, num_fm, embed_dim]

        position_tensor = torch.LongTensor(list(
            range(num_fm))).unsqueeze_(0).repeat(bs, 1).type_as(src_lengths)
        position_tensor = position_tensor.le(
            src_lengths.unsqueeze(1))  # padding positions are 0
        # add position encoding
        if self.embed_positions is not None:
            x = spatiol_feature + self.embed_positions(
                position_tensor)  # [bs, num_fm, embed_dim]
        else:
            x = spatiol_feature

        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = x.transpose(0, 1)  # [num_fm, bs, embed_dim]
        encoder_padding_mask = position_tensor.eq(self.padding_idx)

        if self.layer_wise_attention:
            return_all_hiddens = True

        encoder_states = [] if return_all_hiddens else None

        # encoder layers
        for layer in self.temporal_enc_layers:
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = torch.empty(1).uniform_()
            if not self.training or (dropout_probability >
                                     self.encoder_layerdrop):
                x = layer(x, encoder_padding_mask)
                if return_all_hiddens:
                    assert encoder_states is not None
                    encoder_states.append(x)

        if self.layer_norm is not None:
            x = self.layer_norm(x)
            if return_all_hiddens:
                encoder_states[-1] = x

        return EncoderOut(
            encoder_out=x,  # T x B x C
            encoder_padding_mask=encoder_padding_mask,  # B x T
            encoder_embedding=spatiol_feature,  # B x T x C
            encoder_states=encoder_states,  # List[T x B x C]
        )
Example #7
    def forward(self, model, sample, reduction="sum", log_probs=True):
        encoder_output = model.encoder(tbc=False, **sample["net_input"])
        ctc_logits = encoder_output['encoder_out']
        len_ctc_logits = (
            ~encoder_output['encoder_padding_mask']).long().sum(-1)
        encoder_output = EncoderOut(
            encoder_out=encoder_output['encoded'].transpose(0, 1),  # T x B x C
            encoder_embedding=None,
            encoder_padding_mask=encoder_output[
                'encoder_padding_mask'],  # B x T
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )

        p = max((model.num_updates - model.teacher_forcing_updates) / 2000.0,
                0.0)
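        # p: probability of replacing teacher-forced tokens with the model's own
        # predictions (scheduled sampling); ramps up linearly once teacher forcing ends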
        if model.num_updates <= model.teacher_forcing_updates:
            decoder_out = model.decoder(
                prev_output_tokens=sample["net_input"]["prev_output_tokens"],
                encoder_out=encoder_output)
        else:
            with torch.no_grad():
                decoder_out = model.decoder(
                    prev_output_tokens=sample["net_input"]
                    ["prev_output_tokens"],
                    encoder_out=encoder_output)
                decoded = decoder_out["logits"].argmax(-1).int()
                device = decoded.device
                # the model's own predictions, shifted right with EOS prepended
                prev_self_decoded = torch.cat([
                    torch.ones([decoded.size(0), 1]).int().to(device) *
                    self.task.target_dictionary.eos(), decoded[:, :-1]
                ], 1)
                prev_output = torch.where(
                    (torch.rand(decoded.size()) > p).to(device),
                    sample["net_input"]["prev_output_tokens"],
                    prev_self_decoded)
            decoder_out = model.decoder(prev_output_tokens=prev_output,
                                        encoder_out=encoder_output)

        target = sample["target"]
        target_lengths = sample["target_lengths"]
        lprobs, ctc_loss, ce_loss = self.compute_loss(model, ctc_logits,
                                                      len_ctc_logits,
                                                      decoder_out["logits"],
                                                      target, target_lengths,
                                                      reduction, log_probs)
        sample_size, logging_output = self.get_logging_output(
            sample, target, lprobs, ctc_loss, ce_loss)
        loss = ctc_loss + ce_loss
        logging_output['schedule_sampling'] = p

        if not model.training:
            import editdistance

            c_err = 0
            c_len = 0
            self.decoder.step_forward_fn = model.decoder
            input_lengths = (~encoder_output.encoder_padding_mask).sum(-1)
            with torch.no_grad():
                decodeds = self.decoder.decode(encoder_output, 50)
                for decoded, t, inp_l in zip(decodeds, sample["target"],
                                             input_lengths):
                    decoded = decoded[0]['tokens']

                    keep = (t != self.task.target_dictionary.pad()) & (
                        t != self.task.target_dictionary.eos())
                    targ = t[keep]
                    targ_units_arr = targ.tolist()
                    pred_units_arr = decoded.tolist()

                    c_err += editdistance.eval(pred_units_arr, targ_units_arr)
                    c_len += len(targ_units_arr)

                logging_output["c_errors"] = c_err
                logging_output["c_total"] = c_len

        return loss, sample_size, logging_output
Example #8
    def forward(
        self,
        src_tokens,
        src_lengths,
        cls_input: Optional[Tensor] = None,
        return_all_hiddens: bool = False,
    ):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).

        Returns:
            namedtuple:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
        """
        x, encoder_embedding = self.forward_embedding(src_tokens)
        x = x.transpose(0, 1)  # B x T x C -> T x B x C

        # if return_all_hiddens is False, encoder_states stays an empty list:
        # intermediate hiddens are not returned, but the interface is still satisfied
        encoder_states = []

        # U-Net part:
        x_unet = x
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        x_unet = self.forward_unet(x_unet, encoder_padding_mask)

        # Transformer part:
        x_transformer = x
        for layer in self.transformer_layers:
            x_transformer = layer(x_transformer, encoder_padding_mask)
            if return_all_hiddens:
                assert encoder_states is not None
                encoder_states.append(x_transformer)

        # Combine U-Net representations and Transformer representations
        x, _ = torch.stack([x_transformer, x_unet], dim=0).max(0)
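        # element-wise maximum of the two branches; x stays T x B x C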

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        return EncoderOut(
            encoder_out=x,  # T x B x C
            encoder_padding_mask=encoder_padding_mask,  # B x T
            encoder_embedding=encoder_embedding,  # B x T x C
            encoder_states=encoder_states,  # List[T x B x C]
            src_tokens=src_tokens,
            src_lengths=src_lengths,
        )
Example #9
    def batch_beam_decode(encoder_output,
                          step_forward_fn,
                          incremental_state,
                          SOS_ID,
                          EOS_ID,
                          vocab_size,
                          beam_size=1,
                          max_decode_len=100):
        """
        encoder_output:
            encoder_out=x,  # T x B x C
            encoder_padding_mask=encoder_padding_mask,  # B x T
            encoder_embedding=encoder_embedding,  # B x T x C
            encoder_states=encoder_states,  # List[T x B x C]
        """
        encoded = encoder_output.encoder_out  # T x B x C
        len_encoded = (~encoder_output.encoder_padding_mask).sum(-1)
        batch_size = len_encoded.size(0)
        device = encoded.device
        d_output = vocab_size

        # initialize beam search
        # repeat each sample in the batch along the batch axis: [1,2,3,4] -> [1,1,2,2,3,3,4,4]
        encoded = encoded[:, None, :, :].repeat(
            1, beam_size, 1, 1)  # [batch_size, beam_size, *, hidden_units]
        encoded = encoded.view(batch_size * beam_size, -1, encoded.size(-1))
        len_encoded = len_encoded[:, None].repeat(1, beam_size).view(
            -1)  # [batch_size * beam_size]
        encoder_padding_mask = encoder_output.encoder_padding_mask.repeat(
            1, beam_size).reshape(batch_size * beam_size, -1)

        encoder_output = EncoderOut(
            encoder_out=encoded.transpose(0, 1),  # T x B x C
            encoder_embedding=None,
            encoder_padding_mask=encoder_padding_mask,  # B x T
            encoder_states=None,  # List[T x B x C]
            src_tokens=None,
            src_lengths=None)

        # [[<S>, <S>, ..., <S>]], shape: [batch_size * beam_size, 1]
        preds = torch.ones([batch_size * beam_size, 1
                            ]).long().to(device) * SOS_ID
        logits = torch.zeros([batch_size * beam_size, 0,
                              d_output]).float().to(device)
        len_decoded = torch.ones_like(len_encoded)
        # scores must start as [0, -inf, -inf, ...] because all beams of a
        # sentence hold the same initial prediction
        scores = torch.tensor([0.0] + [-float("inf")] *
                              (beam_size - 1)).float().repeat(batch_size).to(
                                  device)  # [batch_size * beam_size]
        finished = torch.zeros_like(scores).bool().to(device)

        # per-sentence base offsets, used to flatten top-k indices across the beam dimension
        base_indices = torch.arange(batch_size)[:, None].repeat(
            1, beam_size).view(-1).to(device)

        for _ in range(max_decode_len):
            # i, preds, scores, logits, len_decoded, finished
            decoder_output = step_forward_fn(
                prev_output_tokens=preds,
                encoder_out=encoder_output,
                incremental_state=incremental_state)

            cur_logits = decoder_output["logits"]

            logits = torch.cat([logits, cur_logits],
                               1)  # [batch*beam, t, size_output]
            z = F.log_softmax(cur_logits[:, -1, :],
                              dim=-1)  # [batch*beam, size_output]

            # rank the combined scores
            next_scores, next_preds = torch.topk(z,
                                                 k=beam_size,
                                                 sorted=True,
                                                 dim=-1)

            # beamed scores & Pruning
            scores = scores[:,
                            None] + next_scores  # [batch_size * beam_size, beam_size]
            scores = scores.view(batch_size, beam_size * beam_size)

            _, k_indices = torch.topk(scores, k=beam_size)
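            # per sentence, keep the best beam_size of the beam_size * beam_size candidate expansions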
            k_indices = base_indices * beam_size * beam_size + k_indices.view(
                -1)  # [batch_size * beam_size]
            # Update scores.
            scores = scores.view(-1)[k_indices]
            # Update predictions.
            next_preds = next_preds.view(-1)[k_indices]

            # k_indices ranges over [0, batch*beam*beam); preds over [0, batch*beam)
            # preds (and any cached decoder state) are shared across each beam's vocab expansion
            preds = preds[k_indices // beam_size]
            preds = torch.cat([preds, next_preds[:, None]],
                              axis=1)  # [batch_size * beam_size, i]

            has_eos = next_preds.eq(EOS_ID)
            finished = torch.logical_or(finished, has_eos)
            len_decoded += 1 - finished.int()

            if finished.int().sum() == finished.size(0):
                break

        len_decoded -= 1 - finished.int()  # remove one count for beams that never emitted EOS
        preds = preds[:, 1:]
        # torch.topk is used to sort `scores`
        scores_sorted, sorted_idx = torch.topk(scores.view(batch_size, beam_size),
                                               k=beam_size,
                                               sorted=True)
        sorted_idx = base_indices * beam_size + sorted_idx.view(
            -1)  # [batch_size * beam_size]

        # [batch_size * beam_size, ...] -> [batch_size, beam_size, ...]
        preds_sorted = preds[sorted_idx].view(
            batch_size, beam_size, -1)  # [batch_size, beam_size, max_length]
        len_decoded_sorted = len_decoded[sorted_idx].view(batch_size, beam_size)
        scores_sorted = scores[sorted_idx].view(batch_size, beam_size)

        return preds_sorted, len_decoded_sorted, scores_sorted
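
A short sketch of how `batch_beam_decode` might be driven, assuming a fairseq-style decoder whose forward returns a dict with a "logits" key as used above; `model`, `tgt_dict`, `encoder_output`, and the choice of start symbol are assumptions, not part of the example:

    incremental_state = {}
    preds, lengths, scores = batch_beam_decode(
        encoder_output,                      # an EncoderOut built as documented above
        step_forward_fn=model.decoder,       # called with prev_output_tokens / encoder_out / incremental_state
        incremental_state=incremental_state,
        SOS_ID=tgt_dict.eos(),               # assumed start symbol (EOS is a common fairseq convention)
        EOS_ID=tgt_dict.eos(),
        vocab_size=len(tgt_dict),
        beam_size=5,
        max_decode_len=100,
    )
    best_tokens = preds[:, 0]                # highest-scoring hypothesis per sentence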