Ejemplo n.º 1
0
    def encode_streaming(self, xs, params, task='ys'):
        """Encode one utterance block by block to simulate streaming encoding.

        Only the encoder is run here (no decoding). The encoder outputs of every
        block are cached in a ``Streaming`` helper and popped at the end.

        Args:
            xs (list): batch of input feature sequences; must hold exactly one utterance
            params (dict): hyper-parameters for decoding (forwarded to ``Streaming``)
            task (str): task to evaluate; only ``'ys'`` is supported
        Returns:
            eout (FloatTensor): encoder outputs of the whole utterance, `[B, T, ?]`
                (last dim presumably the encoder output size -- TODO confirm)
            elens (IntTensor): `[1]`, the length of ``eout`` along dim 1

        """
        assert task == 'ys'
        assert self.input_type == 'speech'
        assert self.fwd_weight > 0
        assert len(xs) == 1  # batch size
        streaming = Streaming(xs[0], params, self.enc)

        self.enc.reset_cache()
        is_last_block = False
        while not is_last_block:
            # Encode input features block by block
            (x_block, is_last_block, cnn_lookback, cnn_lookahead,
             xlen_block) = streaming.extract_feature()
            enc_outs = self.encode([x_block],
                                   'all',
                                   streaming=True,
                                   cnn_lookback=cnn_lookback,
                                   cnn_lookahead=cnn_lookahead,
                                   xlen_block=xlen_block)
            streaming.cache_eout(enc_outs[task]['xs'])
            streaming.next_block()

        eout = streaming.pop_eouts()
        return eout, torch.IntTensor([eout.size(1)])
Ejemplo n.º 2
0
    def decode_streaming(self,
                         xs,
                         params,
                         idx2token,
                         exclude_eos=False,
                         task='ys'):
        """Simulate streaming decoding. Both encoding and decoding are performed in the online mode.

        The single utterance in ``xs`` is encoded block by block. Segment
        boundaries come from CTC-based VAD (when enabled) or, in
        block-synchronous mode, from the decoder emitting ``<eos>``. Each
        segment is decoded either block-synchronously (``recog_block_sync``)
        or by one offline beam search over the segment's cached encoder
        outputs, and the per-segment results are concatenated.

        Args:
            xs (list): batch of input features; must contain exactly one utterance
            params (dict): hyper-parameters for decoding
            idx2token (callable): converts a sequence of token IDs to text
            exclude_eos (bool): forwarded to the offline beam search
            task (str): task to evaluate; only ``'ys'`` is supported
        Returns:
            tuple: ``([[token_ids]], [None])`` where ``token_ids`` is a numpy
                array of the concatenated best token IDs, or
                ``([[[]]], [None])`` when nothing was recognized
        Raises:
            NotImplementedError: for RNN-T and Transformer decoders, and for
                block-synchronous decoding with any other non-RNN decoder

        """
        assert task == 'ys'
        assert self.input_type == 'speech'
        assert self.ctc_weight > 0
        assert self.fwd_weight > 0
        assert len(xs) == 1  # batch size
        # assert params['recog_length_norm']
        global_params = copy.deepcopy(params)
        global_params['recog_max_len_ratio'] = 1.0
        block_sync = params['recog_block_sync']
        block_size = params['recog_block_sync_size']  # before subsampling

        streaming = Streaming(xs[0], params, self.enc)
        factor = self.enc.subsampling_factor
        block_size //= factor
        if block_sync:
            # block_size is used as a divisor below; fail with a clear message
            # instead of a ZeroDivisionError.
            assert block_size >= 1, "block_size is too small."

        hyps = None  # active hypotheses of the current segment (block-sync)
        end_hyps = []  # ended hypotheses from the latest block-sync step
        best_hyp_id_prefix = []  # best partial hypothesis of the current segment
        best_hyp_id_stream = []  # token IDs committed from finished segments
        is_reset = True  # for the first block

        stdout = False

        self.eval()
        lm = getattr(self, 'lm_fwd', None)
        lm_second = getattr(self, 'lm_second', None)
        # NOTE(review): decoding runs without torch.no_grad() here (the wrapper
        # was commented out in the original); confirm callers disable autograd.
        while True:
            # Encode input features block by block
            x_block, is_last_block, cnn_lookback, cnn_lookahead, xlen_block = streaming.extract_feature(
            )
            if is_reset:
                self.enc.reset_cache()
            eout_block_dict = self.encode([x_block],
                                          'all',
                                          streaming=True,
                                          cnn_lookback=cnn_lookback,
                                          cnn_lookahead=cnn_lookahead,
                                          xlen_block=xlen_block)
            eout_block = eout_block_dict[task]['xs']
            is_reset = False  # detect the first boundary in the same block

            # CTC-based VAD
            if streaming.is_ctc_vad:
                if self.ctc_weight_sub1 > 0:
                    ctc_probs_block = self.dec_fwd_sub1.ctc_probs(
                        eout_block_dict['ys_sub1']['xs'])
                    # TODO: consider subsampling
                else:
                    ctc_probs_block = self.dec_fwd.ctc_probs(eout_block)
                is_reset = streaming.ctc_vad(ctc_probs_block, stdout=stdout)

            # Truncate the most right frames (may leave an empty block when
            # bd_offset == 0)
            if is_reset and not is_last_block and streaming.bd_offset >= 0:
                eout_block = eout_block[:, :streaming.bd_offset]
            streaming.cache_eout(eout_block)

            # Block-synchronous attention decoding
            if isinstance(self.dec_fwd, RNNT):
                raise NotImplementedError
            elif isinstance(self.dec_fwd, RNNDecoder) and block_sync:
                # Feed the block in sub-blocks of `block_size` frames. When the
                # block is empty the loop does not run and `end_hyps`/`hyps`
                # keep their previous values.
                for i in range(math.ceil(eout_block.size(1) / block_size)):
                    eout_block_i = eout_block[:, i * block_size:(i + 1) *
                                              block_size]
                    end_hyps, hyps, _ = self.dec_fwd.beam_search_block_sync(
                        eout_block_i,
                        params,
                        idx2token,
                        hyps,
                        lm,
                        state_carry_over=False)
                # `hyps` is still None if nothing was decoded since the last reset
                merged_hyps = sorted(end_hyps + (hyps or []),
                                     key=lambda x: x['score'],
                                     reverse=True)
                if len(merged_hyps) > 0:
                    best_hyp_id_prefix = np.array(merged_hyps[0]['hyp'][1:])
                    if len(best_hyp_id_prefix
                           ) > 0 and best_hyp_id_prefix[-1] == self.eos:
                        # reset beam if <eos> is generated from the best hypothesis
                        best_hyp_id_prefix = best_hyp_id_prefix[:
                                                                -1]  # exclude <eos>
                        # Segmentation strategy 2:
                        # If <eos> is emitted from the decoder (not CTC),
                        # the current block is segmented.
                        if not is_reset:
                            streaming._bd_offset = eout_block.size(1) - 1
                            # TODO: fix later
                            is_reset = True
                    if len(best_hyp_id_prefix) > 0:
                        print(
                            '\rStreaming (T:%d [10ms], offset:%d [10ms], blank:%d [10ms]): %s'
                            % (streaming.offset + eout_block.size(1) * factor,
                               self.dec_fwd.n_frames * factor,
                               streaming.n_blanks * factor,
                               idx2token(best_hyp_id_prefix)))
            elif isinstance(self.dec_fwd, TransformerDecoder):
                raise NotImplementedError
            elif block_sync:
                # No block-synchronous search is wired up for this decoder
                raise NotImplementedError(self.dec_fwd)

            if is_reset:
                if block_sync:
                    # pick up the best hyp from ended and active hypotheses
                    if len(best_hyp_id_prefix) > 0:
                        best_hyp_id_stream.extend(best_hyp_id_prefix)
                else:
                    # Global decoding over the segmented region
                    eout = streaming.pop_eouts()
                    elens = torch.IntTensor([eout.size(1)])
                    ctc_log_probs = None
                    if params['recog_ctc_weight'] > 0:
                        ctc_log_probs = torch.log(self.dec_fwd.ctc_probs(eout))
                    nbest_hyps_id_offline = self.dec_fwd.beam_search(
                        eout,
                        elens,
                        global_params,
                        idx2token,
                        lm,
                        lm_second,
                        ctc_log_probs=ctc_log_probs,
                        exclude_eos=exclude_eos)[0]
                    if len(nbest_hyps_id_offline[0][0]) > 0:
                        best_hyp_id_stream.extend(nbest_hyps_id_offline[0][0])

                # reset; the committed prefix is cleared too so that it cannot
                # be appended again if the next segment yields no hypotheses
                streaming.reset(stdout=stdout)
                hyps = None
                end_hyps = []
                best_hyp_id_prefix = []

            streaming.next_block()
            if is_last_block:
                break
            # next block will start from the frame next to the boundary
            streaming.backoff(x_block, self.dec_fwd, stdout=stdout)

        # Global decoding for tail blocks
        if not block_sync and streaming.n_cache_block > 0:
            eout = streaming.pop_eouts()
            elens = torch.IntTensor([eout.size(1)])
            nbest_hyps_id_offline = self.dec_fwd.beam_search(
                eout,
                elens,
                global_params,
                idx2token,
                lm,
                lm_second,
                exclude_eos=exclude_eos)[0]
            if len(nbest_hyps_id_offline[0][0]) > 0:
                best_hyp_id_stream.extend(nbest_hyps_id_offline[0][0])

        # pick up the best hyp of the unfinished last segment
        if not is_reset and block_sync and len(best_hyp_id_prefix) > 0:
            best_hyp_id_stream.extend(best_hyp_id_prefix)

        if len(best_hyp_id_stream) > 0:
            return [[np.stack(best_hyp_id_stream, axis=0)]], [None]
        else:
            return [[[]]], [None]
Ejemplo n.º 3
0
    def decode_streaming(self,
                         xs,
                         params,
                         idx2token,
                         exclude_eos=False,
                         task='ys'):
        """Simulate streaming decoding of a single utterance chunk by chunk.

        Encoder outputs are computed chunk by chunk. Segment boundaries come
        from CTC-based VAD (when enabled) or, in chunk-synchronous mode, from
        the decoder emitting ``<eos>``. Each segment is decoded either
        chunk-synchronously (``recog_chunk_sync``) or by one offline beam
        search over the segment's encoder outputs, and the per-segment
        results are concatenated.

        Args:
            xs (list): batch of input features; must contain exactly one utterance
            params (dict): hyper-parameters for decoding
            idx2token (callable): converts a sequence of token IDs to text
            exclude_eos (bool): accepted for interface compatibility; not used here
            task (str): task to evaluate; only ``'ys'`` is supported
        Returns:
            tuple: ``([token_ids], [None])`` where ``token_ids`` is a numpy
                array of the concatenated best token IDs, or ``([[]], [None])``
                when nothing was recognized

        """
        from neural_sp.models.seq2seq.frontends.streaming import Streaming

        # check configurations
        assert task == 'ys'
        assert self.input_type == 'speech'
        assert self.ctc_weight > 0
        assert self.fwd_weight > 0
        assert len(xs) == 1  # batch size
        # Settings for the offline (segment-level) beam search
        global_params = copy.deepcopy(params)
        global_params['recog_max_len_ratio'] = 1.0

        streaming = Streaming(xs[0], params, self.enc, idx2token)

        hyps = None  # active hypotheses of the current segment (chunk-sync)
        best_hyp_id_prefix = []  # best partial hypothesis of the current segment
        best_hyp_id_stream = []  # token IDs committed from finished segments
        is_reset = True  # for the first chunk

        self.eval()
        with torch.no_grad():
            lm = getattr(self, 'lm_fwd', None)
            lm_second = getattr(self, 'lm_second', None)

            while True:
                # Encode input features chunk by chunk
                x_chunk, is_last_chunk = streaming.extract_feature()
                eout_chunk = self.encode([x_chunk],
                                         task,
                                         use_cache=not is_reset,
                                         streaming=True)[task]['xs']
                is_reset = False  # detect the first boundary in the same chunk

                # CTC-based VAD
                ctc_log_probs_chunk = None
                if streaming.is_ctc_vad:
                    ctc_probs_chunk = self.dec_fwd.ctc_probs(eout_chunk)
                    if params['recog_ctc_weight'] > 0:
                        ctc_log_probs_chunk = torch.log(ctc_probs_chunk)
                    is_reset = streaming.ctc_vad(ctc_probs_chunk)

                # Truncate the most right frames
                # NOTE(review): ctc_log_probs_chunk is computed before this
                # truncation and is not truncated with eout_chunk -- confirm
                # beam_search_chunk_sync tolerates the length mismatch.
                if is_reset and not is_last_chunk:
                    eout_chunk = eout_chunk[:, :streaming.bd_offset + 1]
                streaming.eout_chunks.append(eout_chunk)

                # Chunk-synchronous attention decoding
                if params['recog_chunk_sync']:
                    end_hyps, hyps, _ = self.dec_fwd.beam_search_chunk_sync(
                        eout_chunk,
                        params,
                        idx2token,
                        lm,
                        ctc_log_probs=ctc_log_probs_chunk,
                        hyps=hyps,
                        state_carry_over=False,
                        ignore_eos=self.enc.rnn_type in ['lstm', 'conv_lstm'])
                    merged_hyps = sorted(end_hyps + hyps,
                                         key=lambda x: x['score'],
                                         reverse=True)
                    if len(merged_hyps) > 0:
                        best_hyp_id_prefix = np.array(merged_hyps[0]['hyp'][1:])
                        if len(best_hyp_id_prefix
                               ) > 0 and best_hyp_id_prefix[-1] == self.eos:
                            # reset beam if <eos> is generated from the best hypothesis
                            best_hyp_id_prefix = best_hyp_id_prefix[:
                                                                    -1]  # exclude <eos>
                            # Segmentation strategy 2:
                            # If <eos> is emitted from the decoder (not CTC),
                            # the current chunk is segmented.
                            if not is_reset:
                                streaming.bd_offset = eout_chunk.size(1) - 1
                                is_reset = True
                        print(
                            '\rSync MoChA (T:%d, offset:%d, blank:%d frames): %s' %
                            (streaming.offset +
                             eout_chunk.size(1) * streaming.factor,
                             self.dec_fwd.n_frames * streaming.factor,
                             streaming.n_blanks * streaming.factor,
                             idx2token(best_hyp_id_prefix)),
                            end='')

                if is_reset:
                    if not params['recog_chunk_sync']:
                        # Global decoding over the segmented region (run once;
                        # its CTC-rescored result is the one that is kept)
                        eout = torch.cat(streaming.eout_chunks, dim=1)
                        elens = torch.IntTensor([eout.size(1)])
                        ctc_log_probs = None
                        if params['recog_ctc_weight'] > 0:
                            ctc_log_probs = torch.log(
                                self.dec_fwd.ctc_probs(eout))
                        nbest_hyps_id_offline, _, _ = self.dec_fwd.beam_search(
                            eout,
                            elens,
                            global_params,
                            idx2token,
                            lm,
                            lm_second,
                            ctc_log_probs=ctc_log_probs)
                        if len(nbest_hyps_id_offline[0][0]) > 0:
                            best_hyp_id_stream.extend(
                                nbest_hyps_id_offline[0][0])
                    elif len(best_hyp_id_prefix) > 0:
                        # pick up the best hyp from ended and active hypotheses
                        best_hyp_id_stream.extend(best_hyp_id_prefix)

                    # reset; the committed prefix is cleared too so that it
                    # cannot be appended again later
                    streaming.reset()
                    hyps = None
                    best_hyp_id_prefix = []

                    # next chunk will start from the frame next to the boundary
                    if not is_last_chunk and 0 <= streaming.bd_offset * streaming.factor < streaming.N_l - 1:
                        n_back = x_chunk[(streaming.bd_offset + 1) *
                                         streaming.factor:streaming.N_l].shape[0]
                        streaming.offset -= n_back
                        self.dec_fwd.n_frames -= n_back // streaming.factor

                streaming.next_chunk()
                if is_last_chunk:
                    break

            # Global decoding over the last chunk
            if not params['recog_chunk_sync'] and len(
                    streaming.eout_chunks) > 0:
                eout = torch.cat(streaming.eout_chunks, dim=1)
                elens = torch.IntTensor([eout.size(1)])
                nbest_hyps_id_offline, _, _ = self.dec_fwd.beam_search(
                    eout, elens, global_params, idx2token, lm, lm_second, None)
                if len(nbest_hyps_id_offline[0][0]) > 0:
                    best_hyp_id_stream.extend(nbest_hyps_id_offline[0][0])

            # pick up the best hyp of the unfinished last segment
            if not is_reset and params['recog_chunk_sync'] and len(
                    best_hyp_id_prefix) > 0:
                best_hyp_id_stream.extend(best_hyp_id_prefix)

            if len(best_hyp_id_stream) > 0:
                return [np.stack(best_hyp_id_stream, axis=0)], [None]
            else:
                return [[]], [None]
Ejemplo n.º 4
0
    def decode_streaming(self,
                         xs,
                         params,
                         idx2token,
                         exclude_eos=False,
                         task='ys'):
        """Simulate streaming decoding of a single utterance block by block.

        Encoder outputs are computed block by block. Segment boundaries come
        from CTC-based VAD (when enabled) or, in block-synchronous mode, from
        the decoder emitting ``<eos>``. Each segment is decoded either
        block-synchronously (``recog_block_sync``) or by one offline beam
        search over the segment's encoder outputs, and the per-segment
        results are concatenated.

        Args:
            xs (list): batch of input features; must contain exactly one utterance
            params (dict): hyper-parameters for decoding
            idx2token (callable): converts a sequence of token IDs to text
            exclude_eos (bool): accepted for interface compatibility; not used here
            task (str): task to evaluate; only ``'ys'`` is supported
        Returns:
            tuple: ``([[token_ids]], [None])`` where ``token_ids`` is a numpy
                array of the concatenated best token IDs, or
                ``([[[]]], [None])`` when nothing was recognized
        Raises:
            NotImplementedError: if block-synchronous decoding is requested
                with a decoder that is not an ``RNNDecoder``

        """
        assert task == 'ys'
        assert self.input_type == 'speech'
        assert self.ctc_weight > 0
        assert self.fwd_weight > 0
        assert len(xs) == 1  # batch size
        # Settings for the offline (segment-level) beam search
        global_params = copy.deepcopy(params)
        global_params['recog_max_len_ratio'] = 1.0
        block_sync = params['recog_block_sync']
        if block_sync and not isinstance(self.dec_fwd, RNNDecoder):
            # Block-synchronous decoding below is only wired up for RNNDecoder
            raise NotImplementedError(self.dec_fwd)

        streaming = Streaming(xs[0], params, self.enc,
                              params['recog_block_sync_size'])

        hyps = None  # active hypotheses of the current segment (block-sync)
        best_hyp_id_prefix = []  # best partial hypothesis of the current segment
        best_hyp_id_stream = []  # token IDs committed from finished segments
        is_reset = True  # for the first block

        stdout = False

        self.eval()
        with torch.no_grad():
            lm = getattr(self, 'lm_fwd', None)
            lm_second = getattr(self, 'lm_second', None)

            while True:
                # Encode input features block by block
                x_chunk, is_last_block, lookback, lookahead = streaming.extract_feature(
                )
                if is_reset:
                    self.enc.reset_cache()
                eout_chunk_dict = self.encode([x_chunk],
                                              'all',
                                              streaming=True,
                                              lookback=lookback,
                                              lookahead=lookahead)
                eout_chunk = eout_chunk_dict[task]['xs']
                is_reset = False  # detect the first boundary in the same block

                # CTC-based VAD
                if streaming.is_ctc_vad:
                    if self.ctc_weight_sub1 > 0:
                        ctc_probs_chunk = self.dec_fwd_sub1.ctc_probs(
                            eout_chunk_dict['ys_sub1']['xs'])
                        # TODO: consider subsampling
                    else:
                        ctc_probs_chunk = self.dec_fwd.ctc_probs(eout_chunk)
                    is_reset = streaming.ctc_vad(ctc_probs_chunk,
                                                 stdout=stdout)

                # Truncate the most right frames
                if is_reset and not is_last_block and streaming.bd_offset >= 0:
                    eout_chunk = eout_chunk[:, :streaming.bd_offset]
                streaming.eout_chunks.append(eout_chunk)

                # Block-synchronous attention decoding
                if isinstance(self.dec_fwd, RNNDecoder) and block_sync:
                    end_hyps, hyps, _ = self.dec_fwd.beam_search_block_sync(
                        eout_chunk,
                        params,
                        idx2token,
                        lm,
                        hyps=hyps,
                        state_carry_over=False)
                    merged_hyps = sorted(end_hyps + hyps,
                                         key=lambda x: x['score'],
                                         reverse=True)
                    if len(merged_hyps) > 0:
                        best_hyp_id_prefix = np.array(merged_hyps[0]['hyp'][1:])
                        if len(best_hyp_id_prefix
                               ) > 0 and best_hyp_id_prefix[-1] == self.eos:
                            # reset beam if <eos> is generated from the best hypothesis
                            best_hyp_id_prefix = best_hyp_id_prefix[:
                                                                    -1]  # exclude <eos>
                            # Segmentation strategy 2:
                            # If <eos> is emitted from the decoder (not CTC),
                            # the current block is segmented.
                            if not is_reset:
                                streaming.bd_offset = eout_chunk.size(1) - 1
                                is_reset = True
                        if len(best_hyp_id_prefix) > 0:
                            print('\r%s' % (idx2token(best_hyp_id_prefix)))

                if is_reset:
                    if not block_sync:
                        # Global decoding over the segmented region
                        eout = torch.cat(streaming.eout_chunks, dim=1)
                        elens = torch.IntTensor([eout.size(1)])
                        ctc_log_probs = None
                        if params['recog_ctc_weight'] > 0:
                            ctc_log_probs = torch.log(
                                self.dec_fwd.ctc_probs(eout))
                        nbest_hyps_id_offline = self.dec_fwd.beam_search(
                            eout,
                            elens,
                            global_params,
                            idx2token,
                            lm,
                            lm_second,
                            ctc_log_probs=ctc_log_probs)[0]
                        if len(nbest_hyps_id_offline[0][0]) > 0:
                            best_hyp_id_stream.extend(
                                nbest_hyps_id_offline[0][0])
                    elif len(best_hyp_id_prefix) > 0:
                        # pick up the best hyp from ended and active hypotheses
                        best_hyp_id_stream.extend(best_hyp_id_prefix)

                    # reset; the committed prefix is cleared too so that it
                    # cannot be appended again later
                    streaming.reset(stdout=stdout)
                    hyps = None
                    best_hyp_id_prefix = []

                streaming.next_chunk()
                if is_last_block:
                    break
                # next block will start from the frame next to the boundary
                streaming.backoff(x_chunk, self.dec_fwd, stdout=stdout)

            # Global decoding for tail chunks
            if not block_sync and len(streaming.eout_chunks) > 0:
                eout = torch.cat(streaming.eout_chunks, dim=1)
                elens = torch.IntTensor([eout.size(1)])
                nbest_hyps_id_offline = self.dec_fwd.beam_search(
                    eout, elens, global_params, idx2token, lm, lm_second)[0]
                if len(nbest_hyps_id_offline[0][0]) > 0:
                    best_hyp_id_stream.extend(nbest_hyps_id_offline[0][0])

            # pick up the best hyp of the unfinished last segment
            if not is_reset and block_sync and len(best_hyp_id_prefix) > 0:
                best_hyp_id_stream.extend(best_hyp_id_prefix)

            if len(best_hyp_id_stream) > 0:
                return [[np.stack(best_hyp_id_stream, axis=0)]], [None]
            else:
                return [[[]]], [None]
Ejemplo n.º 5
0
    def decode_streaming(self,
                         xs,
                         params,
                         idx2token,
                         exclude_eos=False,
                         task='ys'):
        """Simulate streaming encoding+decoding. Both encoding and decoding are performed in the online mode.

        The single utterance in ``xs`` is encoded block by block and decoded
        block-synchronously with a CTC, RNN-T or RNN attention decoder.
        Segment boundaries come from CTC-based VAD (when enabled), from the
        best hypothesis ending in ``<eos>``, or from the beam running out of
        active hypotheses. The per-segment results are concatenated.

        Args:
            xs (list): batch of input features; must contain exactly one utterance
            params (dict): hyper-parameters for decoding (``recog_block_sync``
                must be enabled)
            idx2token (callable): converts a sequence of token IDs to text
            exclude_eos (bool): accepted for interface compatibility; not used here
            task (str): task to evaluate; only ``'ys'`` is supported
        Returns:
            tuple: ``([[token_ids]], [None])`` where ``token_ids`` is a numpy
                array of the concatenated best token IDs, or
                ``([[[]]], [None])`` when nothing was recognized
        Raises:
            NotImplementedError: for Transformer decoders and other
                unsupported decoder types (unless decoding is CTC-only)

        """
        block_size = params.get('recog_block_sync_size')  # before subsampling
        cache_emb = params.get('recog_cache_embedding')
        ctc_weight = params.get('recog_ctc_weight')

        assert task == 'ys'
        assert self.input_type == 'speech'
        assert self.ctc_weight > 0
        assert self.fwd_weight > 0 or self.ctc_weight == 1.0
        assert len(xs) == 1  # batch size
        assert params.get('recog_block_sync')
        # assert params.get('recog_length_norm')

        streaming = Streaming(xs[0], params, self.enc)
        factor = self.enc.subsampling_factor
        block_size //= factor
        assert block_size >= 1, "block_size is too small."
        # Decode with the CTC branch alone when either CTC weight is 1
        ctc_only = ctc_weight == 1 or self.ctc_weight == 1

        hyps = None  # active hypotheses of the current segment
        end_hyps = []  # ended hypotheses from the latest decoding step
        best_hyp_id_prefix = []  # best partial hypothesis of the current segment
        best_hyp_id_stream = []  # token IDs committed from finished segments
        is_reset = True  # for the first block

        stdout = False

        self.eval()
        helper = BeamSearch(params.get('recog_beam_width'), self.eos,
                            params.get('recog_ctc_weight'),
                            params.get('recog_lm_weight'), self.device)
        lm = getattr(self, 'lm_fwd', None)
        lm_second = getattr(self, 'lm_second', None)
        lm = helper.verify_lm_eval_mode(lm, params.get('recog_lm_weight'),
                                        cache_emb)
        if lm is not None:
            assert isinstance(lm, RNNLM)
        lm_second = helper.verify_lm_eval_mode(
            lm_second, params.get('recog_lm_second_weight'), cache_emb)

        # cache token embeddings
        if cache_emb and self.fwd_weight > 0:
            self.dec_fwd.cache_embedding(self.device)

        while True:
            # Encode input features block by block
            x_block, is_last_block, cnn_lookback, cnn_lookahead, xlen_block = streaming.extract_feature(
            )
            if is_reset:
                self.enc.reset_cache()
            eout_block_dict = self.encode([x_block],
                                          'all',
                                          streaming=True,
                                          cnn_lookback=cnn_lookback,
                                          cnn_lookahead=cnn_lookahead,
                                          xlen_block=xlen_block)
            eout_block = eout_block_dict[task]['xs']
            is_reset = False  # detect the first boundary in the same block

            # CTC-based VAD
            if streaming.is_ctc_vad:
                if self.ctc_weight_sub1 > 0:
                    ctc_probs_block = self.dec_fwd_sub1.ctc.probs(
                        eout_block_dict['ys_sub1']['xs'])
                    # TODO: consider subsampling
                else:
                    ctc_probs_block = self.dec_fwd.ctc.probs(eout_block)
                is_reset = streaming.ctc_vad(ctc_probs_block, stdout=stdout)

            # Truncate the most right frames (may leave an empty block when
            # bd_offset == 0)
            if is_reset and not is_last_block and streaming.bd_offset >= 0:
                eout_block = eout_block[:, :streaming.bd_offset]
            streaming.cache_eout(eout_block)

            # Block-synchronous decoding
            if ctc_only:
                end_hyps, hyps = self.dec_fwd.ctc.beam_search_block_sync(
                    eout_block, params, helper, idx2token, hyps, lm)
            elif isinstance(self.dec_fwd, RNNT):
                end_hyps, hyps = self.dec_fwd.beam_search_block_sync(
                    eout_block, params, helper, idx2token, hyps, lm)
            elif isinstance(self.dec_fwd, RNNDecoder):
                # When the block is empty this loop does not run and
                # `end_hyps`/`hyps` keep their previous values.
                for i in range(math.ceil(eout_block.size(1) / block_size)):
                    eout_block_i = eout_block[:, i * block_size:(i + 1) *
                                              block_size]
                    end_hyps, hyps, _ = self.dec_fwd.beam_search_block_sync(
                        eout_block_i, params, helper, idx2token, hyps, lm)
            elif isinstance(self.dec_fwd, TransformerDecoder):
                raise NotImplementedError
            else:
                raise NotImplementedError(self.dec_fwd)

            # `hyps` is still None if nothing was decoded since the last reset
            merged_hyps = sorted(end_hyps + (hyps or []),
                                 key=lambda x: x['score'],
                                 reverse=True)
            if len(merged_hyps) > 0:
                best_hyp_id_prefix = np.array(merged_hyps[0]['hyp'][1:])
                ends_with_eos = (len(best_hyp_id_prefix) > 0
                                 and best_hyp_id_prefix[-1] == self.eos)
                if ends_with_eos:
                    # Strip only a real <eos>; other trailing tokens are kept
                    best_hyp_id_prefix = best_hyp_id_prefix[:-1]

                if ends_with_eos or not hyps:
                    # Reset the beam if <eos> is generated from the best
                    # hypothesis or no active hypothesis is left.
                    # Segmentation strategy 2:
                    # If <eos> is emitted from the decoder (not CTC),
                    # the current block is segmented.
                    if not is_reset:
                        streaming._bd_offset = eout_block.size(
                            1) - 1  # TODO: fix later
                        is_reset = True

                if len(best_hyp_id_prefix) > 0:
                    n_frames = self.dec_fwd.ctc.n_frames if ctc_only else self.dec_fwd.n_frames
                    print(
                        '\rStreaming (T:%d [10ms], offset:%d [10ms], blank:%d [10ms]): %s'
                        % (streaming.offset + eout_block.size(1) * factor,
                           n_frames * factor, streaming.n_blanks * factor,
                           idx2token(best_hyp_id_prefix)))

            if is_reset:
                # pick up the best hyp from ended and active hypotheses
                if len(best_hyp_id_prefix) > 0:
                    best_hyp_id_stream.extend(best_hyp_id_prefix)

                # reset; the committed prefix is cleared too so that it cannot
                # be appended again if the next segment yields no hypotheses
                streaming.reset(stdout=stdout)
                hyps = None
                end_hyps = []
                best_hyp_id_prefix = []

            streaming.next_block()
            if is_last_block:
                break
            # next block will start from the frame next to the boundary
            streaming.backoff(x_block, self.dec_fwd, stdout=stdout)

        # pick up the best hyp of the unfinished last segment
        if not is_reset and len(best_hyp_id_prefix) > 0:
            best_hyp_id_stream.extend(best_hyp_id_prefix)

        if len(best_hyp_id_stream) > 0:
            return [[np.stack(best_hyp_id_stream, axis=0)]], [None]
        else:
            return [[[]]], [None]
Ejemplo n.º 6
0
    def decode_streaming(self,
                         xs,
                         params,
                         idx2token,
                         exclude_eos=False,
                         speaker=None,
                         task='ys'):
        """Simulate streaming encoding+decoding. Both encoding and decoding are performed in the online mode.

        Input features are encoded block by block and every non-empty encoder
        output block is decoded block-synchronously. When a reset point is
        detected (CTC-based detection, or <eos> at the end of the best
        hypothesis), the best hypothesis of the current session is committed
        to the output and the beam is cleared.

        Args:
            xs (list): batch of inputs; ``len(xs)`` must be 1. ``xs[0]`` is
                handed to ``Streaming`` (presumably a ``[T, idim]`` feature
                array -- confirm against ``Streaming``).
            params (dict): decoding hyper-parameters (``recog_*`` keys).
            idx2token (callable): maps token ids to text; used for the
                progress line printed to stdout.
            exclude_eos (bool): not used by this method; kept for interface
                compatibility.
            speaker: forwarded to the RNN decoder's block-synchronous search.
            task (str): only ``'ys'`` is supported.
        Returns:
            tuple: ``([[hyp_ids]], [None])`` where ``hyp_ids`` is a 1-D numpy
            array of all committed token ids, or ``([[[]]], [None])`` when no
            token was committed.

        """
        self.eval()
        block_size = params.get('recog_block_sync_size')  # before subsampling
        cache_emb = params.get('recog_cache_embedding')
        ctc_weight = params.get('recog_ctc_weight')
        backoff = True

        assert task == 'ys'
        assert self.input_type == 'speech'
        assert self.ctc_weight > 0
        assert self.fwd_weight > 0 or self.ctc_weight == 1.0
        assert len(xs) == 1  # batch size
        assert params.get('recog_block_sync')
        # assert params.get('recog_length_norm')

        streaming = Streaming(xs[0], params, self.enc, idx2token)
        factor = self.enc.subsampling_factor
        block_size //= factor  # convert to frames after subsampling
        assert block_size >= 1, "block_size is too small."
        is_transformer_enc = 'former' in self.enc.enc_type
        # Loop-invariant: decode with the CTC branch only.
        ctc_only = ctc_weight == 1 or self.ctc_weight == 1

        hyps = None
        hyps_nobd = []
        best_hyp_id_session = []
        # Best hypothesis of the current, not-yet-committed session. Bound up
        # front so it exists even if no block ever yields a hypothesis.
        best_hyp_id_prefix = []
        is_reset = False

        helper = BeamSearch(params.get('recog_beam_width'), self.eos,
                            params.get('recog_ctc_weight'),
                            params.get('recog_lm_weight'), self.device)

        lm = getattr(self, 'lm_fwd', None)
        lm_second = getattr(self, 'lm_second', None)
        lm = helper.verify_lm_eval_mode(lm, params.get('recog_lm_weight'),
                                        cache_emb)
        if lm is not None:
            assert isinstance(lm, RNNLM)
        lm_second = helper.verify_lm_eval_mode(
            lm_second, params.get('recog_lm_second_weight'), cache_emb)

        # cache token embeddings
        if cache_emb and self.fwd_weight > 0:
            self.dec_fwd.cache_embedding(self.device)

        self.enc.reset_cache()
        x_block_prev, xlen_block_prev = None, None
        while True:
            # Encode input features block by block
            x_block, is_last_block, cnn_lookback, cnn_lookahead, xlen_block = streaming.extract_feat(
            )
            # Here `is_reset` still holds the previous block's decision: a
            # non-Transformer encoder gets its cache cleared after a reset.
            if not is_transformer_enc and is_reset:
                self.enc.reset_cache()
                if backoff:
                    # Re-run the previous block with the output discarded,
                    # presumably to rebuild the encoder's cached state.
                    self.encode([x_block_prev],
                                'all',
                                streaming=True,
                                cnn_lookback=cnn_lookback,
                                cnn_lookahead=cnn_lookahead,
                                xlen_block=xlen_block_prev)
            # Clear the flag for this block. Doing it here (instead of only
            # when the block has encoder frames) prevents a stale reset from
            # re-committing the same hypothesis on an empty block.
            is_reset = False

            x_block_prev = x_block
            xlen_block_prev = xlen_block
            eout_block_dict = self.encode([x_block],
                                          'all',
                                          streaming=True,
                                          cnn_lookback=cnn_lookback,
                                          cnn_lookahead=cnn_lookahead,
                                          xlen_block=xlen_block)
            eout_block = eout_block_dict[task]['xs']

            if eout_block.size(1) > 0:
                streaming.cache_eout(eout_block)

                # Block-synchronous decoding
                if ctc_only:
                    end_hyps, hyps = self.dec_fwd.ctc.beam_search_block_sync(
                        eout_block, params, helper, idx2token, hyps, lm)
                elif isinstance(self.dec_fwd, RNNT):
                    end_hyps, hyps = self.dec_fwd.beam_search_block_sync(
                        eout_block, params, helper, idx2token, hyps, lm)
                elif isinstance(self.dec_fwd, RNNDecoder):
                    # Feed the encoder output to the decoder in chunks of
                    # `block_size` (post-subsampling) frames.
                    for i in range(math.ceil(eout_block.size(1) / block_size)):
                        eout_block_i = eout_block[:, i * block_size:(i + 1) *
                                                  block_size]
                        end_hyps, hyps, hyps_nobd = self.dec_fwd.beam_search_block_sync(
                            eout_block_i,
                            params,
                            helper,
                            idx2token,
                            hyps,
                            hyps_nobd,
                            lm,
                            speaker=speaker)
                elif isinstance(self.dec_fwd, TransformerDecoder):
                    raise NotImplementedError
                else:
                    raise NotImplementedError(self.dec_fwd)

                # CTC-based reset point detection
                if streaming.enable_ctc_reset_point_detection:
                    if self.ctc_weight_sub1 > 0:
                        ctc_probs_block = self.dec_fwd_sub1.ctc.probs(
                            eout_block_dict['ys_sub1']['xs'])
                        # TODO: consider subsampling
                    else:
                        ctc_probs_block = self.dec_fwd.ctc.probs(eout_block)
                    is_reset = streaming.ctc_reset_point_detection(
                        ctc_probs_block)

                merged_hyps = sorted(end_hyps + hyps + hyps_nobd,
                                     key=lambda x: x['score'],
                                     reverse=True)
                if len(merged_hyps) > 0:
                    # [1:] drops the first token (presumably <sos>).
                    best_hyp_id_prefix_viz = np.array(
                        merged_hyps[0]['hyp'][1:])
                    best_hyp_id_prefix = best_hyp_id_prefix_viz

                    if (len(best_hyp_id_prefix) > 0
                            and best_hyp_id_prefix[-1] == self.eos):
                        # reset beam if <eos> is generated from the best hypothesis
                        best_hyp_id_prefix = best_hyp_id_prefix[:-1]  # exclude <eos>
                        # Condition 2:
                        # If <eos> is emitted from the decoder (not CTC),
                        # the current block is segmented.
                        if (not is_reset) and (not streaming.safeguard_reset):
                            is_reset = True

                    if len(best_hyp_id_prefix_viz) > 0:
                        n_frames = self.dec_fwd.ctc.n_frames if ctc_only else self.dec_fwd.n_frames
                        print(
                            '\rStreaming (T:%.3f [sec], offset:%d [frame], blank:%d [frame]): %s'
                            %
                            ((streaming.offset + eout_block.size(1) * factor) /
                             100, n_frames * factor, streaming.n_blanks *
                             factor, idx2token(best_hyp_id_prefix_viz)))

            if is_reset:
                # Commit the best hypothesis of the finished session.
                if len(best_hyp_id_prefix) > 0:
                    best_hyp_id_session.extend(best_hyp_id_prefix)
                # Forget it so it can never be committed a second time.
                best_hyp_id_prefix = []

                # reset
                streaming.reset()
                hyps = None
                hyps_nobd = []

            streaming.next_block()
            if is_last_block:
                break

        # Commit the hypothesis of the last, still-open session (empty if the
        # final block ended with a reset, since the prefix is cleared then).
        if len(best_hyp_id_prefix) > 0:
            best_hyp_id_session.extend(best_hyp_id_prefix)

        if len(best_hyp_id_session) > 0:
            return [[np.stack(best_hyp_id_session, axis=0)]], [None]
        else:
            return [[[]]], [None]