Example #1
 def ctc_greedy_search(self,
                       speech: torch.Tensor,
                       speech_lengths: torch.Tensor,
                       decoding_chunk_size: int = -1) -> List[List[int]]:
     '''
     param: speech: (batch, max_len, feat_dim)
     param: speech_lengths: (batch,)
     param: decoding_chunk_size:
             <0: for decoding, use the full chunk.
             >0: for decoding, use the fixed chunk size as set.
             0: reserved for training; prohibited here.
     return:
         best path result, with blanks and duplicates removed
     '''
     assert speech.shape[0] == speech_lengths.shape[0]
     assert decoding_chunk_size != 0
     batch_size = speech.shape[0]
     # Let's assume B = batch_size
     encoder_out, encoder_mask = self.encoder(
         speech, speech_lengths, decoding_chunk_size=decoding_chunk_size
     )  # (B, maxlen, encoder_dim)
     maxlen = encoder_out.size(1)
     encoder_out_lens = encoder_mask.squeeze(1).sum(1)
     ctc_probs = self.ctc.log_softmax(
         encoder_out)  # (B, maxlen, vocab_size)
     topk_prob, topk_index = ctc_probs.topk(1, dim=2)  # (B, maxlen, 1)
     topk_index = topk_index.view(batch_size, maxlen)  # (B, maxlen)
     mask = make_pad_mask(encoder_out_lens)  # (B, maxlen)
     topk_index = topk_index.masked_fill_(mask, self.eos)  # (B, maxlen)
     hyps = [hyp.tolist() for hyp in topk_index]
     hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
     return hyps
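
Note: remove_duplicates_and_blank is imported by all of these examples but never shown. A minimal sketch of the standard CTC collapse rule, assuming the blank token id is 0 (WeNet's default): merge consecutive repeats first, then drop blanks, so a token repeated on both sides of a blank survives as two tokens.

from typing import List

def remove_duplicates_and_blank(hyp: List[int],
                                blank_id: int = 0) -> List[int]:
    """Collapse a frame-level CTC path into a token sequence."""
    out: List[int] = []
    prev = None
    for token in hyp:
        # Keep a token only when it differs from the previous frame
        # and is not the blank symbol.
        if token != prev and token != blank_id:
            out.append(token)
        prev = token
    return out

# e.g. [0, 3, 3, 0, 3, 5, 5] -> [3, 3, 5]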
Example #2
    def ctc_greedy_search(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        decoding_chunk_size: int = -1,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
    ) -> List[List[int]]:
        """ Apply CTC greedy search

        Args:
            speech (torch.Tensor): (batch, max_len, feat_dim)
            speech_length (torch.Tensor): (batch, )
            beam_size (int): beam size for beam search
            decoding_chunk_size (int): decoding chunk for dynamic chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here
            simulate_streaming (bool): whether do encoder forward in a
                streaming fashion
        Returns:
            List[List[int]]: best path result
        """
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        batch_size = speech.shape[0]
        #print("speech shape:",speech.shape,"speech_lengths:",speech_lengths)
        # Let's assume B = batch_size
        encoder_out, encoder_mask = self._forward_encoder(
            speech, speech_lengths, decoding_chunk_size,
            num_decoding_left_chunks,
            simulate_streaming)  # (B, maxlen, encoder_dim)
        maxlen = encoder_out.size(1)
        #print("maxlen:",maxlen)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
        #print("encoder_out_lens:",encoder_out_lens)
        ctc_probs = self.ctc.log_softmax(
            encoder_out)  # (B, maxlen, vocab_size)
        topk_prob, topk_index = ctc_probs.topk(1, dim=2)  # (B, maxlen, 1)
        ##print("topk_index:",topk_index.shape,"topk_prob:",topk_prob.shape)
        topk_index = topk_index.view(batch_size, maxlen)  # (B, maxlen)
        #print("topk_index:",topk_index)
        mask = make_pad_mask(encoder_out_lens)  # (B, maxlen)
        #print("mask:",mask)
        topk_index = topk_index.masked_fill_(mask, self.eos)  # (B, maxlen)
        hyps = [hyp.tolist() for hyp in topk_index]
        hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
        return hyps
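
The masked_fill_ call depends on make_pad_mask, which is also not shown. A plausible sketch, assuming it returns True at padded frames so that they can be overwritten with eos before the collapse step:

import torch

def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Boolean (B, max_len) mask, True where frame t >= lengths[b]."""
    max_len = int(lengths.max().item())
    # Broadcast (1, max_len) frame indices against (B, 1) lengths.
    frame_idx = torch.arange(max_len, device=lengths.device).unsqueeze(0)
    return frame_idx >= lengths.unsqueeze(1)

# make_pad_mask(torch.tensor([3, 1])) ->
# tensor([[False, False, False],
#         [False,  True,  True]])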
Example #3
    def ctc_greedy_search(self,
                          speech: torch.Tensor,
                          speech_lengths: torch.Tensor,
                          decoding_chunk_size: int = -1) -> List[List[int]]:
        """ Apply CTC greedy search

        Args:
            speech (torch.Tensor): (batch, max_len, feat_dim)
            speech_length (torch.Tensor): (batch, )
            beam_size (int): beam size for beam search
            decoding_chunk_size (int): decoding chunk for dynamic chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here

        Returns:
            List[List[int]]: best path result
        """
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        batch_size = speech.shape[0]
        # Let's assume B = batch_size
        encoder_out, encoder_mask = self.encoder(
            speech, speech_lengths, decoding_chunk_size=decoding_chunk_size
        )  # (B, maxlen, encoder_dim)
        maxlen = encoder_out.size(1)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
        ctc_probs = self.ctc.log_softmax(
            encoder_out)  # (B, maxlen, vocab_size)
        topk_prob, topk_index = ctc_probs.topk(1, dim=2)  # (B, maxlen, 1)
        topk_index = topk_index.view(batch_size, maxlen)  # (B, maxlen)
        mask = make_pad_mask(encoder_out_lens)  # (B, maxlen)
        topk_index = topk_index.masked_fill_(mask, self.eos)  # (B, maxlen)
        hyps = [hyp.tolist() for hyp in topk_index]
        hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
        return hyps
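
A usage sketch for the method above; model, the 80-dim features, and the lengths are illustrative assumptions, not values taken from the examples:

import torch

# `model` is assumed to be a trained WeNet-style ASR model exposing
# the ctc_greedy_search method shown above.
speech = torch.randn(2, 200, 80)           # (batch, max_len, feat_dim)
speech_lengths = torch.tensor([200, 150])  # (batch,)
with torch.no_grad():
    hyps = model.ctc_greedy_search(speech, speech_lengths,
                                   decoding_chunk_size=-1)
# hyps is a List[List[int]] of token ids, one per utterance; map them
# through the model's symbol table to recover text.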
Example #4
 target_length = target_length.to(device)
 # Let's assume B = batch_size and N = beam_size
 # 1. Encoder
 encoder_out, encoder_mask = model._forward_encoder(
     feat, feats_length)  # (B, maxlen, encoder_dim)
 maxlen = encoder_out.size(1)
 batch_size = encoder_out.size(0)
 ctc_probs = model.ctc.log_softmax(
     encoder_out)  # (B, maxlen, vocab_size)
 encoder_out_lens = encoder_mask.squeeze(1).sum(1)
 topk_prob, topk_index = ctc_probs.topk(1, dim=2)  # (B, maxlen, 1)
 topk_index = topk_index.view(batch_size, maxlen)  # (B, maxlen)
 mask = make_pad_mask(encoder_out_lens)  # (B, maxlen)
 topk_index = topk_index.masked_fill_(mask, eos)  # (B, maxlen)
 alignment = [hyp.tolist() for hyp in topk_index]
 hyps = [remove_duplicates_and_blank(hyp) for hyp in alignment]
 for index, i in enumerate(key):
     content = []
     if len(hyps[index]) > 0:
         for w in hyps[index]:
             if w == eos:
                 break
             content.append(char_dict[w])
     f_ctc_results.write('{} {}\n'.format(i, " ".join(content)))
 f_ctc_results.flush()
 for index, i in enumerate(key):
     timestamp = get_frames_timestamp(alignment[index])
     subsample = get_subsample(configs)
     word_seq, word_time = get_labformat_frames(
         timestamp, subsample, char_dict)
     for index_j in range(len(word_seq)):
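
The snippet is cut off inside the label-writing loop. The timestamp step it sets up rests on a simple relation: one encoder frame spans subsample feature frames, each typically 10 ms. A hypothetical sketch of that conversion (the helper name and the 10 ms frame shift are assumptions, not taken from the examples):

def frame_to_ms(frame_index: int, subsample: int,
                frame_shift_ms: float = 10.0) -> float:
    """Map an encoder frame index back to audio time in milliseconds."""
    # One encoder frame covers `subsample` feature frames of
    # `frame_shift_ms` each (subsample=4 is common for Conformer).
    return frame_index * subsample * frame_shift_ms

# e.g. frame_to_ms(25, subsample=4) -> 1000.0 ms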