Example #1
    def forward(
        self,
        xs_pad: torch.Tensor,
        xs_lens: torch.Tensor,
        decoding_chunk_size: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed positions in tensor.

        Args:
            xs_pad: padded input tensor (B, L, D)
            xs_lens: input length (B)
            decoding_chunk_size: decoding chunk size for dynamic chunk decoding.
                0: default for training, use random dynamic chunk.
                <0: for decoding, use full chunk.
                >0: for decoding, use the fixed chunk size as set.
        Returns:
            encoder output tensor xs and the padding mask after subsampling
        """
        masks = ~make_pad_mask(xs_lens).unsqueeze(1)  # (B, 1, L)
        xs, pos_emb, masks = self.embed(xs_pad, masks)
        chunk_masks = add_optional_chunk_mask(xs, masks,
                                              self.use_dynamic_chunk,
                                              decoding_chunk_size,
                                              self.static_chunk_size)
        for layer in self.encoders:
            xs, chunk_masks = layer(xs, chunk_masks, pos_emb)
        if self.normalize_before:
            xs = self.after_norm(xs)
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks before encoder layers, and the masks will be used
        # for cross attention with decoder later
        return xs, masks
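
The mask helper used throughout these examples is not shown. Below is a minimal sketch of what `make_pad_mask` is assumed to do: it returns a boolean mask with True at padded positions, which is why the encoders invert it with `~` before using it as an attention mask.

import torch

def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Return a bool mask of shape (B, max_len) where True marks padding.

    Minimal sketch; some call sites pass an explicit max_len, others rely on
    the batch maximum.
    """
    batch_size = lengths.size(0)
    max_len = max_len if max_len > 0 else int(lengths.max().item())
    # (1, max_len) grid of positions 0..max_len-1
    seq_range = torch.arange(max_len, device=lengths.device).unsqueeze(0)
    # position >= length  ->  padded
    return seq_range.expand(batch_size, max_len) >= lengths.unsqueeze(1)

# Example: lengths [3, 1] with max_len 4
# make_pad_mask(torch.tensor([3, 1]), 4) ->
# tensor([[False, False, False,  True],
#         [False,  True,  True,  True]])
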
Example #2
    def forward(
        self,
        xs: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: Optional[torch.Tensor] = None,
        decoding_chunk_size: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode input sequence.
        Args:
            xs (torch.Tensor): Input tensor (#batch, time, idim).
            ilens (torch.Tensor): Input lengths (#batch,).
            prev_states (torch.Tensor): Not used in this forward
                (kept for interface compatibility).
            decoding_chunk_size (int): decoding chunk size for dynamic chunk.
                0: default for training, use random dynamic chunk.
                <0: for decoding, use full chunk.
                >0: for decoding, use the fixed chunk size as set.
        Returns:
            torch.Tensor: Output tensor (#batch, time, attention_dim).
            torch.Tensor: Mask tensor (#batch, 1, time).
        """
        masks = ~make_pad_mask(ilens).unsqueeze(1)
        # self.embed returns xs as a tuple: (x, position_encoding)
        xs, masks = self.embed(xs, masks)
        chunk_masks = add_optional_chunk_mask(xs[0], masks,
                                              self.use_dynamic_chunk,
                                              decoding_chunk_size,
                                              self.static_chunk_size)
        for layer in self.encoders:
            xs, chunk_masks = layer(xs, chunk_masks)
        if isinstance(xs, tuple):
            xs = xs[0]
        if self.normalize_before:
            xs = self.after_norm(xs)

        return xs, masks
Example #3
 def ctc_greedy_search(self,
                       speech: torch.Tensor,
                       speech_lengths: torch.Tensor,
                       decoding_chunk_size: int = -1) -> List[List[int]]:
     '''
     param: speech: (batch, max_len, feat_dim)
     param: speech_lengths: (batch, )
     param: decoding_chunk_size:
             <0: for decoding, use full chunk.
             >0: for decoding, use the fixed chunk size as set.
             0: used for training; it is prohibited here.
     return:
         best path result, with blanks and duplicates removed
     '''
     assert speech.shape[0] == speech_lengths.shape[0]
     assert decoding_chunk_size != 0
     device = speech.device
     batch_size = speech.shape[0]
     # Let's assume B = batch_size
     encoder_out, encoder_mask = self.encoder(
         speech, speech_lengths, decoding_chunk_size=decoding_chunk_size
     )  # (B, maxlen, encoder_dim)
     maxlen = encoder_out.size(1)
     encoder_out_lens = encoder_mask.squeeze(1).sum(1)
     ctc_probs = self.ctc.log_softmax(
         encoder_out)  # (B, maxlen, vocab_size)
     topk_prob, topk_index = ctc_probs.topk(1, dim=2)  # (B, maxlen, 1)
     topk_index = topk_index.view(batch_size, maxlen)  # (B, maxlen)
     mask = make_pad_mask(encoder_out_lens)  # (B, maxlen)
     topk_index = topk_index.masked_fill_(mask, self.eos)  # (B, maxlen)
     hyps = [hyp.tolist() for hyp in topk_index]
     hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
     return hyps
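
`remove_duplicates_and_blank` implements the CTC best-path post-processing. A minimal sketch, assuming blank id 0 (the real helper may hard-code the blank id differently):

from typing import List

def remove_duplicates_and_blank(hyp: List[int], blank_id: int = 0) -> List[int]:
    """Collapse consecutive repeats, then drop blanks (CTC best path)."""
    new_hyp: List[int] = []
    prev = None
    for token in hyp:
        if token != prev:
            new_hyp.append(token)
        prev = token
    return [t for t in new_hyp if t != blank_id]

# remove_duplicates_and_blank([0, 3, 3, 0, 0, 5, 5, 5, 0]) -> [3, 5]
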
Example #4
    def ctc_greedy_search(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        decoding_chunk_size: int = -1,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
    ) -> List[List[int]]:
        """ Apply CTC greedy search

        Args:
            speech (torch.Tensor): (batch, max_len, feat_dim)
            speech_lengths (torch.Tensor): (batch, )
            decoding_chunk_size (int): decoding chunk for a dynamic-chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use the fixed chunk size as set.
                0: used for training; it is prohibited here.
            num_decoding_left_chunks (int): number of left chunks to use
                during decoding; <0 means all left chunks.
            simulate_streaming (bool): whether to do the encoder forward in a
                streaming fashion
        Returns:
            List[List[int]]: best path result
        """
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        batch_size = speech.shape[0]
        #print("speech shape:",speech.shape,"speech_lengths:",speech_lengths)
        # Let's assume B = batch_size
        encoder_out, encoder_mask = self._forward_encoder(
            speech, speech_lengths, decoding_chunk_size,
            num_decoding_left_chunks,
            simulate_streaming)  # (B, maxlen, encoder_dim)
        maxlen = encoder_out.size(1)
        #print("maxlen:",maxlen)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
        #print("encoder_out_lens:",encoder_out_lens)
        ctc_probs = self.ctc.log_softmax(
            encoder_out)  # (B, maxlen, vocab_size)
        topk_prob, topk_index = ctc_probs.topk(1, dim=2)  # (B, maxlen, 1)
        ##print("topk_index:",topk_index.shape,"topk_prob:",topk_prob.shape)
        topk_index = topk_index.view(batch_size, maxlen)  # (B, maxlen)
        #print("topk_index:",topk_index)
        mask = make_pad_mask(encoder_out_lens)  # (B, maxlen)
        #print("mask:",mask)
        topk_index = topk_index.masked_fill_(mask, self.eos)  # (B, maxlen)
        hyps = [hyp.tolist() for hyp in topk_index]
        hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
        return hyps
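
The `masked_fill_(mask, self.eos)` step replaces greedy outputs at padded encoder frames with eos, so padding never contributes real tokens and downstream consumers can cut at eos (as Example #12 does). A tiny worked example, with eos id 2 assumed for illustration:

import torch

eos = 2  # assumed eos id, for illustration only
# greedy token ids per frame for two utterances padded to 4 frames
topk_index = torch.tensor([[7, 7, 0, 4],
                           [9, 0, 5, 5]])
# padding positions for encoder_out_lens = [4, 2]
mask = torch.tensor([[False, False, False, False],
                     [False, False, True,  True]])
topk_index = topk_index.masked_fill_(mask, eos)
# tensor([[7, 7, 0, 4],
#         [9, 0, 2, 2]])  -> padded frames now carry eos
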
Example #5
    def forward(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        decoding_chunk_size: int = 0,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed positions in tensor.

        Args:
            xs: padded input tensor (B, T, D)
            xs_lens: input length (B)
            decoding_chunk_size: decoding chunk size for dynamic chunk
                0: default for training, use random dynamic chunk.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
            num_decoding_left_chunks: number of left chunks for decoding;
                the chunk size is decoding_chunk_size.
                >=0: use num_decoding_left_chunks
                <0: use all left chunks
        Returns:
            encoder output tensor xs, and subsampled masks
            xs: padded output tensor (B, T' ~= T/subsample_rate, D)
            masks: torch.Tensor batch padding mask after subsample
                (B, 1, T' ~= T/subsample_rate)
        """
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        xs, pos_emb, masks = self.embed(xs, masks)
        mask_pad = masks  # (B, 1, T/subsample_rate)
        chunk_masks = add_optional_chunk_mask(xs, masks,
                                              self.use_dynamic_chunk,
                                              self.use_dynamic_left_chunk,
                                              decoding_chunk_size,
                                              self.static_chunk_size,
                                              num_decoding_left_chunks)
        for layer in self.encoders:
            xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        if self.normalize_before:
            xs = self.after_norm(xs)
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks before encoder layers, and the masks will be used
        # for cross attention with decoder later
        return xs, masks
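
`add_optional_chunk_mask` is not shown in these examples. A sketch of the chunk-causal mask it is assumed to combine with the padding mask: each position attends to its own chunk and, optionally, a limited number of left chunks (the training-time random dynamic chunk branch is omitted here).

import torch

def subsequent_chunk_mask(size: int, chunk_size: int,
                          num_left_chunks: int = -1) -> torch.Tensor:
    """(size, size) bool mask: True where attention is allowed.

    Sketch of the chunk-causal mask assumed inside add_optional_chunk_mask.
    """
    ret = torch.zeros(size, size, dtype=torch.bool)
    for i in range(size):
        if num_left_chunks < 0:
            start = 0  # attend to all left chunks
        else:
            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
        end = min((i // chunk_size + 1) * chunk_size, size)
        ret[i, start:end] = True
    return ret

# chunk_size=2, no left-chunk limit:
# subsequent_chunk_mask(4, 2) ->
# tensor([[ True,  True, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True,  True],
#         [ True,  True,  True,  True]])
# The final chunk_masks is this mask AND-ed with the (B, 1, T) padding mask,
# broadcast to (B, T, T).
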
Example #6
 def forward(
     self,
     memory: torch.Tensor,
     memory_mask: torch.Tensor,
     ys_in_pad: torch.Tensor,
     ys_in_lens: torch.Tensor,
     r_ys_in_pad: torch.Tensor = torch.empty(0),
     reverse_weight: float = 0.0,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """Forward decoder.
     Args:
         memory: encoded memory, float32  (batch, maxlen_in, feat)
         memory_mask: encoder memory mask, (batch, 1, maxlen_in)
         ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
         ys_in_lens: input lengths of this batch (batch)
         r_ys_in_pad: not used in transformer decoder, in order to unify api
             with bidirectional decoder
          reverse_weight: not used in transformer decoder, in order to unify
              the api with the bidirectional decoder
      Returns:
          (tuple): tuple containing:
              x: decoded token scores before softmax (batch, maxlen_out,
                  vocab_size) if use_output_layer is True
              torch.tensor(0.0): a placeholder, to unify the api with the
                  bidirectional decoder
              olens: (batch, )
     """
     tgt = ys_in_pad
     maxlen = tgt.size(1)
     # tgt_mask: (B, 1, L)
     tgt_mask = ~make_pad_mask(ys_in_lens, maxlen).unsqueeze(1)
     tgt_mask = tgt_mask.to(tgt.device)
     # m: (1, L, L)
     m = subsequent_mask(tgt_mask.size(-1),
                         device=tgt_mask.device).unsqueeze(0)
     # tgt_mask: (B, L, L)
     tgt_mask = tgt_mask & m
     x, _ = self.embed(tgt)
     for layer in self.decoders:
         x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory,
                                                  memory_mask)
     if self.normalize_before:
         x = self.after_norm(x)
     if self.use_output_layer:
         x = self.output_layer(x)
     olens = tgt_mask.sum(1)
     return x, torch.tensor(0.0), olens
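
`subsequent_mask` is the standard causal mask; a minimal sketch of the behavior the decoder above relies on (position i can attend only to positions 0..i):

import torch

def subsequent_mask(size: int,
                    device: torch.device = torch.device("cpu")) -> torch.Tensor:
    """Lower-triangular causal mask of shape (size, size). Minimal sketch."""
    return torch.tril(torch.ones(size, size, dtype=torch.bool, device=device))

# subsequent_mask(3) ->
# tensor([[ True, False, False],
#         [ True,  True, False],
#         [ True,  True,  True]])
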
Example #7
    def forward(
        self,
        memory: torch.Tensor,
        memory_mask: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            memory: encoded memory, float32  (batch, maxlen_in, feat)
            memory_mask: encoder memory mask, (batch, 1, maxlen_in)
            ys_in_pad:
                input token ids, int64 (batch, maxlen_out)
                    if input_layer == "embed";
                otherwise, input tensor (batch, maxlen_out, #mels)
            ys_in_lens: (batch)
        Returns:
            (tuple): tuple containing:

            x: decoded token score before softmax (batch, maxlen_out, token)
                if use_output_layer is True,
            olens: (batch, )
        """
        tgt = ys_in_pad
        # tgt_mask: (B, 1, L)
        tgt_mask = (~make_pad_mask(ys_in_lens).unsqueeze(1)).to(tgt.device)
        # m: (1, L, L)
        m = subsequent_mask(tgt_mask.size(-1),
                            device=tgt_mask.device).unsqueeze(0)
        # tgt_mask: (B, L, L)
        tgt_mask = tgt_mask & m

        x, _ = self.embed(tgt)
        for layer in self.decoders:
            x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory,
                                                     memory_mask)
        if self.normalize_before:
            x = self.after_norm(x)
        if self.use_output_layer:
            x = self.output_layer(x)

        olens = tgt_mask.sum(1)
        return x, olens
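
The `tgt_mask = tgt_mask & m` step combines the (B, 1, L) padding mask with the (1, L, L) causal mask by broadcasting to (B, L, L). A tiny worked example with target lengths [3, 2] (made-up values):

import torch

# (B, 1, L) padding mask for lengths [3, 2], L = 3
tgt_mask = torch.tensor([[[True, True, True]],
                         [[True, True, False]]])
# (1, L, L) causal mask
m = torch.tril(torch.ones(3, 3, dtype=torch.bool)).unsqueeze(0)
combined = tgt_mask & m   # (B, L, L)
# combined[0] is the full causal triangle; combined[1]:
# tensor([[ True, False, False],
#         [ True,  True, False],
#         [ True,  True, False]])
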
Example #8
    def forward(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        decoding_chunk_size: int = 0,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed positions in tensor.

        Args:
            xs: padded input tensor (B, L, D)
            xs_lens: input length (B)
            decoding_chunk_size: decoding chunk size for dynamic chunk
                0: default for training, use random dynamic chunk.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
            num_decoding_left_chunks: number of left chunks for decoding;
                the chunk size is decoding_chunk_size.
                >=0: use num_decoding_left_chunks
                <0: use all left chunks
        Returns:
            encoder output tensor xs and the subsampled padding mask
        """
        masks = ~make_pad_mask(xs_lens).unsqueeze(1)  # (B, 1, L); pad mask from per-utterance lengths, since valid lengths differ within a batch
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        xs, pos_emb, masks = self.embed(xs, masks)  # Conv2dSubsampling4 with RelPositionalEncoding: subsampling happens here
        mask_pad = masks
        chunk_masks = add_optional_chunk_mask(xs, masks,
                                              self.use_dynamic_chunk,
                                              self.use_dynamic_left_chunk,
                                              decoding_chunk_size,
                                              self.static_chunk_size,
                                              num_decoding_left_chunks)
        for layer in self.encoders:
            xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        if self.normalize_before:
            xs = self.after_norm(xs)
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks before encoder layers, and the masks will be used
        # for cross attention with decoder later
        return xs, masks
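
`self.global_cmvn` is not defined in these examples. A hedged sketch of what it is assumed to do: per-dimension mean/variance normalization with statistics precomputed on the training data; the attribute names `mean`, `istd`, and `norm_var` are assumptions for illustration.

import torch

class GlobalCMVN(torch.nn.Module):
    """Global cepstral mean and variance normalization (sketch).

    mean, istd: (feat_dim,) statistics precomputed from training data.
    """
    def __init__(self, mean: torch.Tensor, istd: torch.Tensor,
                 norm_var: bool = True):
        super().__init__()
        self.norm_var = norm_var
        self.register_buffer("mean", mean)
        self.register_buffer("istd", istd)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T, feat_dim)
        x = x - self.mean
        if self.norm_var:
            x = x * self.istd
        return x
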
Example #9
 def forward(self,
             encoder_out: torch.Tensor,
             encoder_lens: torch.Tensor,
             hyps_pad_sos: torch.Tensor,
             hyps_lens: torch.Tensor,
             r_hyps_pad_sos: torch.Tensor):
     """Encoder
     Args:
         encoder_out: B x T x F
         encoder_lens: B
         hyps_pad_sos: B x beam x T2,
                     hyps with sos and padded by ignore id
         hyps_lens: B x beam, length for each hyp with sos
         r_hyps_pad_sos: B x beam x T2,
                 reversed hyps with sos and padded by ignore id
     Returns:
         decoder_out: B x beam x T2 x V
         r_decoder_out: B x beam x T2 x V
     """
     B, T, F = encoder_out.shape
     bz = self.beam_size
     B2 = B * bz
     encoder_out = encoder_out.repeat(1, bz, 1).view(B2, T, F)
     encoder_mask = ~make_pad_mask(encoder_lens, T).unsqueeze(1)
     encoder_mask = encoder_mask.repeat(1, bz, 1).view(B2, 1, T)
     T2 = hyps_pad_sos.shape[2]
     hyps_pad = hyps_pad_sos.view(B2, T2)
     hyps_lens = hyps_lens.view(B2,)
     r_hyps_pad = r_hyps_pad_sos.view(B2, T2)
     decoder_out, r_decoder_out, _ = self.decoder(
         encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad,
         self.reverse_weight)
     decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
     V = decoder_out.shape[-1]
     decoder_out = decoder_out.view(B, bz, T2, V)
     r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
     r_decoder_out = r_decoder_out.view(B, bz, T2, V)
     return decoder_out, r_decoder_out
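
Why `repeat(1, bz, 1).view(B2, T, F)` is a valid beam expansion: repeating along dim 1 lays down bz contiguous copies of each utterance's T frames, so the reshape keeps every copy aligned with its source utterance. A tiny check with made-up sizes (B=2, T=2, F=1, beam=2):

import torch

B, T, F, bz = 2, 2, 1, 2
encoder_out = torch.tensor([[[1.], [2.]],     # utterance 0
                            [[3.], [4.]]])    # utterance 1
tiled = encoder_out.repeat(1, bz, 1).view(B * bz, T, F)
# tiled[0] and tiled[1] are copies of utterance 0,
# tiled[2] and tiled[3] are copies of utterance 1:
# tensor([[[1.], [2.]],
#         [[1.], [2.]],
#         [[3.], [4.]],
#         [[3.], [4.]]])
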
Example #10
    def ctc_greedy_search(self,
                          speech: torch.Tensor,
                          speech_lengths: torch.Tensor,
                          decoding_chunk_size: int = -1) -> List[List[int]]:
        """ Apply CTC greedy search

        Args:
            speech (torch.Tensor): (batch, max_len, feat_dim)
            speech_lengths (torch.Tensor): (batch, )
            decoding_chunk_size (int): decoding chunk for a dynamic-chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use the fixed chunk size as set.
                0: used for training; it is prohibited here.

        Returns:
            List[List[int]]: best path result
        """
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        batch_size = speech.shape[0]
        # Let's assume B = batch_size
        encoder_out, encoder_mask = self.encoder(
            speech, speech_lengths, decoding_chunk_size=decoding_chunk_size
        )  # (B, maxlen, encoder_dim)
        maxlen = encoder_out.size(1)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
        ctc_probs = self.ctc.log_softmax(
            encoder_out)  # (B, maxlen, vocab_size)
        topk_prob, topk_index = ctc_probs.topk(1, dim=2)  # (B, maxlen, 1)
        topk_index = topk_index.view(batch_size, maxlen)  # (B, maxlen)
        mask = make_pad_mask(encoder_out_lens)  # (B, maxlen)
        topk_index = topk_index.masked_fill_(mask, self.eos)  # (B, maxlen)
        hyps = [hyp.tolist() for hyp in topk_index]
        hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
        return hyps
Example #11
    def forward(self, encoder_out: torch.Tensor, encoder_lens: torch.Tensor,
                hyps_pad_sos_eos: torch.Tensor, hyps_lens_sos: torch.Tensor,
                r_hyps_pad_sos_eos: torch.Tensor, ctc_score: torch.Tensor):
        """Encoder
        Args:
            encoder_out: B x T x F
            encoder_lens: B
            hyps_pad_sos_eos: B x beam x (T2+1),
                        hyps with sos & eos and padded by ignore id
            hyps_lens_sos: B x beam, length for each hyp with sos
            r_hyps_pad_sos_eos: B x beam x (T2+1),
                    reversed hyps with sos & eos and padded by ignore id
            ctc_score: B x beam, ctc score for each hyp
        Returns:
            best_index: B
        """
        B, T, F = encoder_out.shape
        bz = self.beam_size
        B2 = B * bz
        encoder_out = encoder_out.repeat(1, bz, 1).view(B2, T, F)
        encoder_mask = ~make_pad_mask(encoder_lens, T).unsqueeze(1)
        encoder_mask = encoder_mask.repeat(1, bz, 1).view(B2, 1, T)
        T2 = hyps_pad_sos_eos.shape[2] - 1
        hyps_pad = hyps_pad_sos_eos.view(B2, T2 + 1)
        hyps_lens = hyps_lens_sos.view(B2, )
        hyps_pad_sos = hyps_pad[:, :-1].contiguous()
        hyps_pad_eos = hyps_pad[:, 1:].contiguous()

        r_hyps_pad = r_hyps_pad_sos_eos.view(B2, T2 + 1)
        r_hyps_pad_sos = r_hyps_pad[:, :-1].contiguous()
        r_hyps_pad_eos = r_hyps_pad[:, 1:].contiguous()

        decoder_out, r_decoder_out, _ = self.decoder(encoder_out, encoder_mask,
                                                     hyps_pad_sos, hyps_lens,
                                                     r_hyps_pad_sos,
                                                     self.reverse_weight)
        decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
        V = decoder_out.shape[-1]
        decoder_out = decoder_out.view(B2, T2, V)
        mask = ~make_pad_mask(hyps_lens, T2)  # B2 x T2
        # mask index, remove ignore id
        index = torch.unsqueeze(hyps_pad_eos * mask, 2)
        score = decoder_out.gather(2, index).squeeze(2)  # B2 X T2
        # mask padded part
        score = score * mask
        decoder_out = decoder_out.view(B, bz, T2, V)
        if self.reverse_weight > 0:
            r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out,
                                                            dim=-1)
            r_decoder_out = r_decoder_out.view(B2, T2, V)
            index = torch.unsqueeze(r_hyps_pad_eos * mask, 2)
            r_score = r_decoder_out.gather(2, index).squeeze(2)
            r_score = r_score * mask
            score = score * (
                1 - self.reverse_weight) + self.reverse_weight * r_score
            r_decoder_out = r_decoder_out.view(B, bz, T2, V)
        score = torch.sum(score, axis=1)  # B2
        score = torch.reshape(score, (B, bz)) + self.ctc_weight * ctc_score
        best_index = torch.argmax(score, dim=1)
        return best_index
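
The rescoring core above is a `gather` on the vocabulary dimension: for each step, pick the log-probability of the reference token (`hyps_pad_eos`), zero out padded steps, and sum over time. A tiny sketch with made-up numbers (2 hypotheses, 3 steps, vocabulary of 4):

import torch

# (B2, T2, V) decoder log-probs: B2 = 2 hyps, T2 = 3 steps, V = 4 tokens
decoder_out = torch.log_softmax(torch.randn(2, 3, 4), dim=-1)
# reference tokens per step (eos-shifted targets); padded steps hold 0
hyps_pad_eos = torch.tensor([[1, 3, 2],
                             [2, 2, 0]])
mask = torch.tensor([[True, True, True],     # hyp length 3
                     [True, True, False]])   # hyp length 2
index = torch.unsqueeze(hyps_pad_eos * mask, 2)          # (B2, T2, 1)
score = decoder_out.gather(2, index).squeeze(2) * mask   # (B2, T2), padded steps zeroed
total = score.sum(dim=1)                                 # (B2,) per-hypothesis log-prob
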
Example #12
 feat = feat.to(device)
 target = target.to(device)
 feats_length = feats_length.to(device)
 target_length = target_length.to(device)
 # Let's assume B = batch_size and N = beam_size
 # 1. Encoder
 encoder_out, encoder_mask = model._forward_encoder(
     feat, feats_length)  # (B, maxlen, encoder_dim)
 maxlen = encoder_out.size(1)
 batch_size = encoder_out.size(0)
 ctc_probs = model.ctc.log_softmax(
     encoder_out)  # (B, maxlen, vocab_size)
 encoder_out_lens = encoder_mask.squeeze(1).sum(1)
 topk_prob, topk_index = ctc_probs.topk(1, dim=2)  # (B, maxlen, 1)
 topk_index = topk_index.view(batch_size, maxlen)  # (B, maxlen)
 mask = make_pad_mask(encoder_out_lens)  # (B, maxlen)
 topk_index = topk_index.masked_fill_(mask, eos)  # (B, maxlen)
 alignment = [hyp.tolist() for hyp in topk_index]
 hyps = [remove_duplicates_and_blank(hyp) for hyp in alignment]
 for index, i in enumerate(key):
     content = []
     if len(hyps[index]) > 0:
         for w in hyps[index]:
             if w == eos:
                 break
             content.append(char_dict[w])
     f_ctc_results.write('{} {}\n'.format(i, " ".join(content)))
 f_ctc_results.flush()
 for index, i in enumerate(key):
     timestamp = get_frames_timestamp(alignment[index])
     subsample = get_subsample(configs)
Example #13
    def forward(self, chunk_xs, chunk_lens, offset,
                att_cache, cnn_cache, cache_mask):
        """Streaming Encoder
        Args:
            chunk_xs (torch.Tensor): chunk input, with shape (b, time, mel-dim),
                where `time == (chunk_size - 1) * subsample_rate + \
                        subsample.right_context + 1`
            chunk_lens (torch.Tensor): valid lengths of chunk_xs, with shape (b,)
            offset (torch.Tensor): offset with shape (b, 1);
                the trailing dim of size 1 is retained for triton deployment
            required_cache_size (int): cache size required for next chunk
                computation
                > 0: actual cache size
                <= 0: not allowed in the streaming gpu encoder
            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, with shape
                (b, elayers, head, cache_t1, d_k * 2), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * num_decoding_left_chunks`.
            cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
                (b, elayers, hidden-dim, cache_t2), where
                `cache_t2 == cnn.lorder - 1`
            cache_mask (torch.Tensor): cache mask with shape
                (b, required_cache_size). In a batch of requests, each request
                may have a different history cache; the cache mask is used to
                indicate the effective cache for each request.
        Returns:
            torch.Tensor: log probabilities of ctc output and cutoff by beam size
                with shape (b, chunk_size, beam)
            torch.Tensor: index of top beam size probabilities for each timestep
                with shape (b, chunk_size, beam)
            torch.Tensor: output of current input xs,
                with shape (b, chunk_size, hidden-dim).
            torch.Tensor: output lengths of the current chunk, with shape (b,).
            torch.Tensor: new offset for the next chunk, with shape (b, 1).
            torch.Tensor: new attention cache required for next chunk, with
                same shape (b, elayers, head, cache_t1, d_k * 2)
                as the original att_cache
            torch.Tensor: new conformer cnn cache required for next chunk, with
                same shape as the original cnn_cache.
            torch.Tensor: new cache mask, with same shape as the original
                cache mask
        """
        offset = offset.squeeze(1)
        T = chunk_xs.size(1)
        chunk_mask = ~make_pad_mask(chunk_lens, T).unsqueeze(1)
        # B X 1 X T
        chunk_mask = chunk_mask.to(chunk_xs.dtype)
        # transpose batch & num_layers dim
        att_cache = torch.transpose(att_cache, 0, 1)
        cnn_cache = torch.transpose(cnn_cache, 0, 1)

        # rewrite encoder.forward_chunk
        # <---------forward_chunk START--------->
        xs = self.global_cmvn(chunk_xs)
        # the chunk mask is important for batched inference since
        # different sequences in a batch have different lengths
        xs, pos_emb, chunk_mask = self.embed(xs, chunk_mask, offset)
        cache_size = att_cache.size(3)  # required cache size
        masks = torch.cat((cache_mask, chunk_mask), dim=2)
        index = offset - cache_size

        pos_emb = self.embed.position_encoding(index, cache_size + xs.size(1))
        pos_emb = pos_emb.to(dtype=xs.dtype)

        next_cache_start = -self.required_cache_size
        r_cache_mask = masks[:, :, next_cache_start:]

        r_att_cache = []
        r_cnn_cache = []
        for i, layer in enumerate(self.encoder.encoders):
            xs, _, new_att_cache, new_cnn_cache = layer(
                xs, masks, pos_emb,
                att_cache=att_cache[i],
                cnn_cache=cnn_cache[i])
            #   shape(new_att_cache) is (B, head, attention_key_size, d_k * 2),
            #   shape(new_cnn_cache) is (B, hidden-dim, cache_t2)
            r_att_cache.append(new_att_cache[:, :, next_cache_start:, :].unsqueeze(1))
            if not self.transformer:
                r_cnn_cache.append(new_cnn_cache.unsqueeze(1))
        if self.encoder.normalize_before:
            chunk_out = self.encoder.after_norm(xs)
        else:
            chunk_out = xs

        r_att_cache = torch.cat(r_att_cache, dim=1)  # concat on layers idx
        if not self.transformer:
            r_cnn_cache = torch.cat(r_cnn_cache, dim=1)  # concat on layers

        # <---------forward_chunk END--------->

        log_ctc_probs = self.ctc.log_softmax(chunk_out)
        log_probs, log_probs_idx = torch.topk(log_ctc_probs,
                                              self.beam_size,
                                              dim=2)
        log_probs = log_probs.to(chunk_xs.dtype)

        r_offset = offset + chunk_out.shape[1]
        # the op below is not supported in TensorRT
        # chunk_out_lens = torch.div(chunk_lens, subsampling_rate,
        #                   rounding_mode='floor')
        chunk_out_lens = chunk_lens // self.subsampling_rate
        r_offset = r_offset.unsqueeze(1)

        return log_probs, log_probs_idx, chunk_out, chunk_out_lens, \
            r_offset, r_att_cache, r_cnn_cache, r_cache_mask
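
A hedged sketch of how a caller might drive the streaming encoder above: the attention cache, cnn cache, cache mask, and offset returned for one chunk are fed back in with the next chunk. The sizes, the zero initialization, and the names `streaming_encoder` and `chunk_iterator` are assumptions for illustration, not the actual deployment code.

import torch

# Assumed sizes, for illustration only
b, elayers, head, d_k = 1, 12, 4, 64
hidden = head * d_k
required_cache_size = 16      # chunk_size * num_decoding_left_chunks (assumed)
cache_t2 = 7                  # cnn.lorder - 1 (assumed)

# zero-initialized caches for the first chunk
att_cache = torch.zeros(b, elayers, head, required_cache_size, d_k * 2)
cnn_cache = torch.zeros(b, elayers, hidden, cache_t2)
cache_mask = torch.zeros(b, 1, required_cache_size)
offset = torch.zeros(b, 1, dtype=torch.long)

# `streaming_encoder` would be an instance of the module above and
# `chunk_iterator` a hypothetical source of (chunk_xs, chunk_lens) pairs.
for chunk_xs, chunk_lens in chunk_iterator:
    (log_probs, log_probs_idx, chunk_out, chunk_out_lens,
     offset, att_cache, cnn_cache, cache_mask) = streaming_encoder(
        chunk_xs, chunk_lens, offset, att_cache, cnn_cache, cache_mask)
    # log_probs / log_probs_idx would feed the downstream per-chunk CTC search
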