Example #1
0
    def decode_streaming(self,
                         xs,
                         params,
                         idx2token,
                         exclude_eos=False,
                         task='ys'):
        """Decode one utterance chunk by chunk (streaming recognition).

        The utterance is pushed through the encoder one chunk at a time.
        Two decoding modes are selected by ``params['recog_chunk_sync']``:

        * chunk-synchronous: every encoded chunk is decoded right away with
          ``self.dec_fwd.beam_search_chunk_sync`` and the current best prefix
          is printed as a progress line;
        * segment-wise ("global"): encoded chunks are buffered in
          ``streaming.eout_chunks`` and decoded with
          ``self.dec_fwd.beam_search`` once a segment boundary is found
          (and once more for whatever is left after the last chunk).

        A segment boundary (``is_reset``) is raised either by the CTC-based
        VAD (``streaming.ctc_vad``) or, in chunk-synchronous mode, when the
        best hypothesis ends with ``self.eos``.

        Args:
            xs (list): input features; must contain exactly one utterance
            params (dict): decoding hyperparameters. Keys read directly here:
                ``recog_chunk_sync`` and ``recog_ctc_weight``. The dict is
                also forwarded to ``Streaming`` and to the decoder.
            idx2token (callable): maps a sequence of token ids to a printable
                value; used for the progress line and forwarded to the decoder
            exclude_eos (bool): not used by this method
            task (str): must be ``'ys'``

        Returns:
            tuple: ``([hyp], [None])`` where ``hyp`` is ``np.stack`` of the
                token ids collected over all segments, or ``([[]], [None])``
                when no token was decoded.

        Raises:
            AssertionError: if the model/task configuration is not supported
                (see the checks at the top of the body).

        """
        from neural_sp.models.seq2seq.frontends.streaming import Streaming

        # check configurations
        assert task == 'ys'
        assert self.input_type == 'speech'
        assert self.ctc_weight > 0
        assert self.fwd_weight > 0
        assert len(xs) == 1  # batch size

        # Segment-wise decoding works on a private copy of the parameters in
        # which the maximum length ratio is forced to 1.0.
        global_params = copy.deepcopy(params)
        global_params['recog_max_len_ratio'] = 1.0
        chunk_sync = params['recog_chunk_sync']

        streaming = Streaming(xs[0], params, self.enc, idx2token)

        hyps = None
        best_hyp_id_stream = []
        is_reset = True  # for the first chunk

        self.eval()
        with torch.no_grad():
            lm = getattr(self, 'lm_fwd', None)
            lm_second = getattr(self, 'lm_second', None)

            while True:
                # Encode input features chunk by chunk. The encoder cache is
                # only reused when the previous chunk did not end a segment.
                x_chunk, is_last_chunk = streaming.extract_feature()
                eout_chunk = self.encode([x_chunk],
                                         task,
                                         use_cache=not is_reset,
                                         streaming=True)[task]['xs']
                is_reset = False  # detect the first boundary in the same chunk

                # CTC-based VAD
                ctc_log_probs_chunk = None
                if streaming.is_ctc_vad:
                    ctc_probs_chunk = self.dec_fwd.ctc_probs(eout_chunk)
                    if params['recog_ctc_weight'] > 0:
                        ctc_log_probs_chunk = torch.log(ctc_probs_chunk)
                    is_reset = streaming.ctc_vad(ctc_probs_chunk)

                # Keep the frames up to and including the boundary and drop
                # everything to the right of it.
                if is_reset and not is_last_chunk:
                    eout_chunk = eout_chunk[:, :streaming.bd_offset + 1]
                streaming.eout_chunks.append(eout_chunk)

                # Chunk-synchronous attention decoding
                if chunk_sync:
                    end_hyps, hyps, _ = self.dec_fwd.beam_search_chunk_sync(
                        eout_chunk,
                        params,
                        idx2token,
                        lm,
                        ctc_log_probs=ctc_log_probs_chunk,
                        hyps=hyps,
                        state_carry_over=False,
                        ignore_eos=self.enc.rnn_type in ['lstm', 'conv_lstm'])
                    merged_hyps = sorted(end_hyps + hyps,
                                         key=lambda x: x['score'],
                                         reverse=True)
                    # hyp[0] is skipped (presumably the start symbol -- confirm)
                    best_hyp_id_prefix = np.array(merged_hyps[0]['hyp'][1:])
                    if (len(best_hyp_id_prefix) > 0
                            and best_hyp_id_prefix[-1] == self.eos):
                        # reset beam if <eos> is generated from the best hypothesis
                        best_hyp_id_prefix = best_hyp_id_prefix[:-1]  # exclude <eos>
                        # Segmentation strategy 2:
                        # If <eos> is emitted from the decoder (not CTC),
                        # the current chunk is segmented.
                        if not is_reset:
                            streaming.bd_offset = eout_chunk.size(1) - 1
                            is_reset = True
                    print(
                        '\rSync MoChA (T:%d, offset:%d, blank:%d frames): %s' %
                        (streaming.offset +
                         eout_chunk.size(1) * streaming.factor,
                         self.dec_fwd.n_frames * streaming.factor,
                         streaming.n_blanks * streaming.factor,
                         idx2token(best_hyp_id_prefix)),
                        end='')

                if is_reset:
                    if not chunk_sync:
                        # Global decoding over the segmented region. This is
                        # done exactly once per segment so that the CTC scores
                        # requested via recog_ctc_weight are actually used.
                        eout = torch.cat(streaming.eout_chunks, dim=1)
                        elens = torch.IntTensor([eout.size(1)])
                        ctc_log_probs = None
                        if params['recog_ctc_weight'] > 0:
                            ctc_log_probs = torch.log(
                                self.dec_fwd.ctc_probs(eout))
                        nbest_hyps_id_offline, _, _ = self.dec_fwd.beam_search(
                            eout,
                            elens,
                            global_params,
                            idx2token,
                            lm,
                            lm_second,
                            ctc_log_probs=ctc_log_probs)
                        if len(nbest_hyps_id_offline[0][0]) > 0:
                            best_hyp_id_stream.extend(
                                nbest_hyps_id_offline[0][0])
                    elif len(best_hyp_id_prefix) > 0:
                        # chunk-synchronous mode: commit the current prefix
                        best_hyp_id_stream.extend(best_hyp_id_prefix)

                    # reset
                    streaming.reset()
                    hyps = None

                    # next chunk will start from the frame next to the boundary
                    if (not is_last_chunk
                            and 0 <= streaming.bd_offset * streaming.factor
                            < streaming.N_l - 1):
                        # input frames of x_chunk between the boundary and
                        # streaming.N_l that have to be fed again
                        n_back = x_chunk[
                            (streaming.bd_offset + 1) *
                            streaming.factor:streaming.N_l].shape[0]
                        streaming.offset -= n_back
                        self.dec_fwd.n_frames -= n_back // streaming.factor

                streaming.next_chunk()
                if is_last_chunk:
                    break

            # Global decoding over the last chunk(s) that were never closed
            # by a boundary (no CTC scores are passed here).
            if not chunk_sync and len(streaming.eout_chunks) > 0:
                eout = torch.cat(streaming.eout_chunks, dim=1)
                elens = torch.IntTensor([eout.size(1)])
                nbest_hyps_id_offline, _, _ = self.dec_fwd.beam_search(
                    eout, elens, global_params, idx2token, lm, lm_second, None)
                if len(nbest_hyps_id_offline[0][0]) > 0:
                    best_hyp_id_stream.extend(nbest_hyps_id_offline[0][0])

            # pick up the best hyp of the last, still open segment
            if not is_reset and chunk_sync and len(best_hyp_id_prefix) > 0:
                best_hyp_id_stream.extend(best_hyp_id_prefix)

            if len(best_hyp_id_stream) > 0:
                return [np.stack(best_hyp_id_stream, axis=0)], [None]
            else:
                return [[]], [None]
Example #2
0
    def decode_streaming(self,
                         xs,
                         params,
                         idx2token,
                         exclude_eos=False,
                         task='ys'):
        """Decode one utterance block by block (streaming recognition).

        The utterance is pushed through the encoder one block at a time.
        Two decoding modes are selected by ``params['recog_block_sync']``:

        * block-synchronous: every encoded block is decoded right away with
          ``self.dec_fwd.beam_search_block_sync`` and the current best prefix
          is printed (only available when ``self.dec_fwd`` is an
          ``RNNDecoder``);
        * segment-wise ("global"): encoded blocks are buffered in
          ``streaming.eout_chunks`` and decoded with
          ``self.dec_fwd.beam_search`` once a segment boundary is found
          (and once more for whatever is left after the last block).

        A segment boundary (``is_reset``) is raised either by the CTC-based
        VAD (``streaming.ctc_vad``) or, in block-synchronous mode, when the
        best hypothesis ends with ``self.eos``.

        Args:
            xs (list): input features; must contain exactly one utterance
            params (dict): decoding hyperparameters. Keys read directly here:
                ``recog_block_sync``, ``recog_block_sync_size`` and
                ``recog_ctc_weight``. The dict is also forwarded to
                ``Streaming`` and to the decoder.
            idx2token (callable): maps a sequence of token ids to a printable
                value; used for the progress print and forwarded to the decoder
            exclude_eos (bool): not used by this method
            task (str): must be ``'ys'``

        Returns:
            tuple: ``([[hyp]], [None])`` where ``hyp`` is ``np.stack`` of the
                token ids collected over all segments, or ``([[[]]], [None])``
                when no token was decoded.

        Raises:
            AssertionError: if the model/task configuration is not supported
                (see the checks at the top of the body).
            NotImplementedError: if block-synchronous decoding is requested
                but ``self.dec_fwd`` is not an ``RNNDecoder``.

        """
        assert task == 'ys'
        assert self.input_type == 'speech'
        assert self.ctc_weight > 0
        assert self.fwd_weight > 0
        assert len(xs) == 1  # batch size

        # Segment-wise decoding works on a private copy of the parameters in
        # which the maximum length ratio is forced to 1.0.
        global_params = copy.deepcopy(params)
        global_params['recog_max_len_ratio'] = 1.0
        block_sync = params['recog_block_sync']

        # Block-synchronous hypotheses are only produced for RNNDecoder. For
        # any other decoder `best_hyp_id_prefix` would never be assigned and
        # the method would crash later with an UnboundLocalError, so reject
        # that configuration up front with a clear message.
        if block_sync and not isinstance(self.dec_fwd, RNNDecoder):
            raise NotImplementedError(
                'Block-synchronous decoding (recog_block_sync) is only '
                'supported for RNNDecoder.')

        streaming = Streaming(xs[0], params, self.enc,
                              params['recog_block_sync_size'])

        hyps = None
        best_hyp_id_stream = []
        is_reset = True  # for the first block

        stdout = False  # forwarded to the Streaming helper calls below

        self.eval()
        with torch.no_grad():
            lm = getattr(self, 'lm_fwd', None)
            lm_second = getattr(self, 'lm_second', None)

            while True:
                # Encode input features block by block. The encoder cache is
                # cleared whenever the previous block ended a segment.
                (x_chunk, is_last_block,
                 lookback, lookahead) = streaming.extract_feature()
                if is_reset:
                    self.enc.reset_cache()
                eout_chunk_dict = self.encode([x_chunk],
                                              'all',
                                              streaming=True,
                                              lookback=lookback,
                                              lookahead=lookahead)
                eout_chunk = eout_chunk_dict[task]['xs']
                is_reset = False  # detect the first boundary in the same block

                # CTC-based VAD (uses the sub1 CTC branch when it is enabled)
                if streaming.is_ctc_vad:
                    if self.ctc_weight_sub1 > 0:
                        ctc_probs_chunk = self.dec_fwd_sub1.ctc_probs(
                            eout_chunk_dict['ys_sub1']['xs'])
                        # TODO: consider subsampling
                    else:
                        ctc_probs_chunk = self.dec_fwd.ctc_probs(eout_chunk)
                    is_reset = streaming.ctc_vad(ctc_probs_chunk,
                                                 stdout=stdout)

                # Keep only the frames before the boundary offset
                if is_reset and not is_last_block and streaming.bd_offset >= 0:
                    eout_chunk = eout_chunk[:, :streaming.bd_offset]
                streaming.eout_chunks.append(eout_chunk)

                # Block-synchronous attention decoding (the decoder is known
                # to be an RNNDecoder here because of the check above)
                if block_sync:
                    end_hyps, hyps, _ = self.dec_fwd.beam_search_block_sync(
                        eout_chunk,
                        params,
                        idx2token,
                        lm,
                        hyps=hyps,
                        state_carry_over=False)
                    merged_hyps = sorted(end_hyps + hyps,
                                         key=lambda x: x['score'],
                                         reverse=True)
                    # hyp[0] is skipped (presumably the start symbol -- confirm)
                    best_hyp_id_prefix = np.array(merged_hyps[0]['hyp'][1:])
                    if (len(best_hyp_id_prefix) > 0
                            and best_hyp_id_prefix[-1] == self.eos):
                        # reset beam if <eos> is generated from the best hypothesis
                        best_hyp_id_prefix = best_hyp_id_prefix[:-1]  # exclude <eos>
                        # Segmentation strategy 2:
                        # If <eos> is emitted from the decoder (not CTC),
                        # the current block is segmented.
                        if not is_reset:
                            streaming.bd_offset = eout_chunk.size(1) - 1
                            is_reset = True
                    if len(best_hyp_id_prefix) > 0:
                        print('\r%s' % (idx2token(best_hyp_id_prefix)))

                if is_reset:
                    if not block_sync:
                        # Global decoding over the segmented region
                        eout = torch.cat(streaming.eout_chunks, dim=1)
                        elens = torch.IntTensor([eout.size(1)])
                        ctc_log_probs = None
                        if params['recog_ctc_weight'] > 0:
                            ctc_log_probs = torch.log(
                                self.dec_fwd.ctc_probs(eout))
                        nbest_hyps_id_offline = self.dec_fwd.beam_search(
                            eout,
                            elens,
                            global_params,
                            idx2token,
                            lm,
                            lm_second,
                            ctc_log_probs=ctc_log_probs)[0]
                        if len(nbest_hyps_id_offline[0][0]) > 0:
                            best_hyp_id_stream.extend(
                                nbest_hyps_id_offline[0][0])
                    elif len(best_hyp_id_prefix) > 0:
                        # block-synchronous mode: commit the current prefix
                        best_hyp_id_stream.extend(best_hyp_id_prefix)

                    # reset
                    streaming.reset(stdout=stdout)
                    hyps = None

                streaming.next_chunk()
                if is_last_block:
                    break
                # next block will start from the frame next to the boundary
                streaming.backoff(x_chunk, self.dec_fwd, stdout=stdout)

            # Global decoding for tail chunks that were never closed by a
            # boundary (no CTC scores are passed here).
            if not block_sync and len(streaming.eout_chunks) > 0:
                eout = torch.cat(streaming.eout_chunks, dim=1)
                elens = torch.IntTensor([eout.size(1)])
                nbest_hyps_id_offline = self.dec_fwd.beam_search(
                    eout, elens, global_params, idx2token, lm, lm_second)[0]
                if len(nbest_hyps_id_offline[0][0]) > 0:
                    best_hyp_id_stream.extend(nbest_hyps_id_offline[0][0])

            # pick up the best hyp of the last, still open segment
            if not is_reset and block_sync and len(best_hyp_id_prefix) > 0:
                best_hyp_id_stream.extend(best_hyp_id_prefix)

            if len(best_hyp_id_stream) > 0:
                return [[np.stack(best_hyp_id_stream, axis=0)]], [None]
            else:
                return [[[]]], [None]