예제 #1
0
    def encode_streaming(self, xs, params, task='ys'):
        """Encode one utterance block by block, emulating a streaming encoder.

        Only the encoder is driven incrementally here; the concatenated
        outputs are returned so that decoding can be performed offline.

        Args:
            xs (FloatTensor): `[B, T, idim]` (only `B == 1` is accepted)
            params (dict): hyper-parameters for decoding
            task (str): task to evaluate (only 'ys' is accepted)
        Returns:
            eout (FloatTensor): encoder outputs as returned by
                `Streaming.pop_eouts()`; dimension 1 is used as the length
            elens (IntTensor): `[B]`

        """
        assert task == 'ys'
        assert self.input_type == 'speech'
        assert self.fwd_weight > 0
        assert len(xs) == 1  # batch size
        stream = Streaming(xs[0], params, self.enc)

        self.enc.reset_cache()
        reached_end = False
        while not reached_end:
            # Fetch the next feature block together with its CNN context info
            (feat_block, reached_end, lookback, lookahead,
             feat_len) = stream.extract_feature()
            enc_outputs = self.encode([feat_block],
                                      'all',
                                      streaming=True,
                                      cnn_lookback=lookback,
                                      cnn_lookahead=lookahead,
                                      xlen_block=feat_len)
            # Keep the block output so that the full sequence can be popped later
            stream.cache_eout(enc_outputs[task]['xs'])
            stream.next_block()

        eout = stream.pop_eouts()
        return eout, torch.IntTensor([eout.size(1)])
예제 #2
0
    def decode_streaming(self,
                         xs,
                         params,
                         idx2token,
                         exclude_eos=False,
                         task='ys'):
        """Simulate streaming decoding. Both encoding and decoding are performed in the online mode.

        The utterance is encoded block by block. A segment boundary ("reset")
        is declared either by `streaming.ctc_vad()` or, in block-synchronous
        mode, when the best hypothesis ends with <eos>. At each boundary the
        best hypothesis of the segment is appended to the output and the
        beam is restarted.

        Args:
            xs (list): input features; must contain exactly one utterance
            params (dict): hyper-parameters for decoding. This method reads
                `recog_block_sync`, `recog_block_sync_size` and
                `recog_ctc_weight` and forwards the dict to the decoder.
            idx2token (callable): applied to the token-id array when the
                partial hypothesis is printed; also forwarded to the decoder
            exclude_eos (bool): forwarded to `beam_search()` (used only when
                `recog_block_sync` is false)
            task (str): task to evaluate (only 'ys' is accepted)
        Returns:
            tuple: `([[hyp]], [None])` where `hyp` is the array of token ids
                accumulated over all segments, or `([[[]]], [None])` when
                nothing was recognized
        Raises:
            NotImplementedError: if the forward decoder is an `RNNT` or a
                `TransformerDecoder`

        """
        assert task == 'ys'
        assert self.input_type == 'speech'
        assert self.ctc_weight > 0
        assert self.fwd_weight > 0
        assert len(xs) == 1  # batch size
        # assert params['recog_length_norm']
        global_params = copy.deepcopy(params)
        global_params['recog_max_len_ratio'] = 1.0
        block_sync = params['recog_block_sync']
        block_size = params['recog_block_sync_size']  # before subsampling

        streaming = Streaming(xs[0], params, self.enc)
        factor = self.enc.subsampling_factor
        block_size //= factor
        # A zero block size would raise ZeroDivisionError in math.ceil() below
        assert not block_sync or block_size >= 1, "block_size is too small."

        hyps = None
        # Initialized up front so that a reset (or the end of the input)
        # reached before any hypothesis exists does not raise NameError.
        best_hyp_id_prefix = []
        best_hyp_id_stream = []
        is_reset = True  # for the first block

        stdout = False

        self.eval()
        lm = getattr(self, 'lm_fwd', None)
        lm_second = getattr(self, 'lm_second', None)

        def _decode_cached_segment():
            """Run the offline beam search over the cached encoder outputs.

            Shared by the in-loop and the tail global decoding so that both
            apply the same CTC scores.
            """
            eout = streaming.pop_eouts()
            elens = torch.IntTensor([eout.size(1)])
            ctc_log_probs = None
            if params['recog_ctc_weight'] > 0:
                ctc_log_probs = torch.log(self.dec_fwd.ctc_probs(eout))
            return self.dec_fwd.beam_search(eout,
                                            elens,
                                            global_params,
                                            idx2token,
                                            lm,
                                            lm_second,
                                            ctc_log_probs=ctc_log_probs,
                                            exclude_eos=exclude_eos)[0]

        # with torch.no_grad():
        while True:
            # Encode input features block by block
            x_block, is_last_block, cnn_lookback, cnn_lookahead, xlen_block = streaming.extract_feature(
            )
            if is_reset:
                self.enc.reset_cache()
            eout_block_dict = self.encode([x_block],
                                          'all',
                                          streaming=True,
                                          cnn_lookback=cnn_lookback,
                                          cnn_lookahead=cnn_lookahead,
                                          xlen_block=xlen_block)
            eout_block = eout_block_dict[task]['xs']
            is_reset = False  # detect the first boundary in the same block

            # CTC-based VAD
            if streaming.is_ctc_vad:
                if self.ctc_weight_sub1 > 0:
                    ctc_probs_block = self.dec_fwd_sub1.ctc_probs(
                        eout_block_dict['ys_sub1']['xs'])
                    # TODO: consider subsampling
                else:
                    ctc_probs_block = self.dec_fwd.ctc_probs(eout_block)
                is_reset = streaming.ctc_vad(ctc_probs_block, stdout=stdout)

            # Truncate the most right frames
            if is_reset and not is_last_block and streaming.bd_offset >= 0:
                eout_block = eout_block[:, :streaming.bd_offset]
            streaming.cache_eout(eout_block)

            # Block-synchronous attention decoding
            if isinstance(self.dec_fwd, RNNT):
                raise NotImplementedError
            elif isinstance(self.dec_fwd, RNNDecoder) and block_sync:
                # The truncation above can leave an empty block. Skip decoding
                # then: the loop below would not run and `end_hyps`/`hyps`
                # would be unbound, stale or None.
                if eout_block.size(1) > 0:
                    for i in range(math.ceil(eout_block.size(1) / block_size)):
                        eout_block_i = eout_block[:, i * block_size:(i + 1) *
                                                  block_size]
                        end_hyps, hyps, _ = self.dec_fwd.beam_search_block_sync(
                            eout_block_i,
                            params,
                            idx2token,
                            hyps,
                            lm,
                            state_carry_over=False)
                    merged_hyps = sorted(end_hyps + hyps,
                                         key=lambda x: x['score'],
                                         reverse=True)
                    if len(merged_hyps) > 0:
                        # drop the first token of the hypothesis
                        best_hyp_id_prefix = np.array(
                            merged_hyps[0]['hyp'][1:])
                        if len(best_hyp_id_prefix
                               ) > 0 and best_hyp_id_prefix[-1] == self.eos:
                            # reset beam if <eos> is generated from the best hypothesis
                            best_hyp_id_prefix = best_hyp_id_prefix[:
                                                                    -1]  # exclude <eos>
                            # Segmentation strategy 2:
                            # If <eos> is emitted from the decoder (not CTC),
                            # the current block is segmented.
                            if not is_reset:
                                streaming._bd_offset = eout_block.size(1) - 1
                                # TODO: fix later
                                is_reset = True
                        if len(best_hyp_id_prefix) > 0:
                            print(
                                '\rStreaming (T:%d [10ms], offset:%d [10ms], blank:%d [10ms]): %s'
                                % (streaming.offset +
                                   eout_block.size(1) * factor,
                                   self.dec_fwd.n_frames * factor,
                                   streaming.n_blanks * factor,
                                   idx2token(best_hyp_id_prefix)))
            elif isinstance(self.dec_fwd, TransformerDecoder):
                raise NotImplementedError

            if is_reset:
                # pick up the best hyp from ended and active hypotheses
                if block_sync:
                    if len(best_hyp_id_prefix) > 0:
                        best_hyp_id_stream.extend(best_hyp_id_prefix)
                else:
                    # Global decoding over the segmented region
                    nbest_hyps_id_offline = _decode_cached_segment()
                    if len(nbest_hyps_id_offline[0][0]) > 0:
                        best_hyp_id_stream.extend(nbest_hyps_id_offline[0][0])

                # reset
                streaming.reset(stdout=stdout)
                hyps = None
                # The prefix has been committed; clear it so that it cannot be
                # appended a second time if the next segment yields nothing.
                best_hyp_id_prefix = []

            streaming.next_block()
            if is_last_block:
                break
            # next block will start from the frame next to the boundary
            streaming.backoff(x_block, self.dec_fwd, stdout=stdout)

        # Global decoding for tail blocks
        if not block_sync and streaming.n_cache_block > 0:
            nbest_hyps_id_offline = _decode_cached_segment()
            if len(nbest_hyps_id_offline[0][0]) > 0:
                best_hyp_id_stream.extend(nbest_hyps_id_offline[0][0])

        # pick up the best hyp (the last segment was not closed by a reset)
        if not is_reset and block_sync and len(best_hyp_id_prefix) > 0:
            best_hyp_id_stream.extend(best_hyp_id_prefix)

        if len(best_hyp_id_stream) > 0:
            return [[np.stack(best_hyp_id_stream, axis=0)]], [None]
        else:
            return [[[]]], [None]
예제 #3
0
    def decode_streaming(self,
                         xs,
                         params,
                         idx2token,
                         exclude_eos=False,
                         task='ys'):
        """Simulate streaming encoding+decoding. Both encoding and decoding are performed in the online mode.

        The utterance is encoded block by block and each block is decoded
        block-synchronously. A segment boundary ("reset") is declared either
        by `streaming.ctc_vad()` or when the best hypothesis ends with <eos>
        (or no active hypothesis is left). At each boundary the best
        hypothesis of the segment is appended to the output and the beam is
        restarted.

        Args:
            xs (list): input features; must contain exactly one utterance
            params (dict): hyper-parameters for decoding. This method reads
                `recog_block_sync` (must be truthy), `recog_block_sync_size`,
                `recog_cache_embedding`, `recog_ctc_weight`,
                `recog_beam_width`, `recog_lm_weight` and
                `recog_lm_second_weight`.
            idx2token (callable): applied to the token-id array when the
                partial hypothesis is printed; also forwarded to the decoder
            exclude_eos (bool): not used by this method
            task (str): task to evaluate (only 'ys' is accepted)
        Returns:
            tuple: `([[hyp]], [None])` where `hyp` is the array of token ids
                accumulated over all segments, or `([[[]]], [None])` when
                nothing was recognized
        Raises:
            NotImplementedError: if the forward decoder is a
                `TransformerDecoder` or an unsupported type

        """
        block_size = params.get('recog_block_sync_size')  # before subsampling
        cache_emb = params.get('recog_cache_embedding')
        ctc_weight = params.get('recog_ctc_weight')

        assert task == 'ys'
        assert self.input_type == 'speech'
        assert self.ctc_weight > 0
        assert self.fwd_weight > 0 or self.ctc_weight == 1.0
        assert len(xs) == 1  # batch size
        assert params.get('recog_block_sync')
        # assert params.get('recog_length_norm')

        streaming = Streaming(xs[0], params, self.enc)
        factor = self.enc.subsampling_factor
        block_size //= factor
        assert block_size >= 1, "block_size is too small."

        hyps = None
        # Initialized up front so that a reset (or the end of the input)
        # reached before any hypothesis exists does not raise NameError.
        best_hyp_id_prefix = []
        best_hyp_id_stream = []
        is_reset = True  # for the first block

        stdout = False

        self.eval()
        helper = BeamSearch(params.get('recog_beam_width'), self.eos,
                            params.get('recog_ctc_weight'),
                            params.get('recog_lm_weight'), self.device)
        lm = getattr(self, 'lm_fwd', None)
        lm_second = getattr(self, 'lm_second', None)
        lm = helper.verify_lm_eval_mode(lm, params.get('recog_lm_weight'),
                                        cache_emb)
        if lm is not None:
            assert isinstance(lm, RNNLM)
        lm_second = helper.verify_lm_eval_mode(
            lm_second, params.get('recog_lm_second_weight'), cache_emb)

        # cache token embeddings
        if cache_emb and self.fwd_weight > 0:
            self.dec_fwd.cache_embedding(self.device)

        # Decode with the CTC branch alone when either weight equals one
        ctc_only = ctc_weight == 1 or self.ctc_weight == 1

        while True:
            # Encode input features block by block
            x_block, is_last_block, cnn_lookback, cnn_lookahead, xlen_block = streaming.extract_feature(
            )
            if is_reset:
                self.enc.reset_cache()
            eout_block_dict = self.encode([x_block],
                                          'all',
                                          streaming=True,
                                          cnn_lookback=cnn_lookback,
                                          cnn_lookahead=cnn_lookahead,
                                          xlen_block=xlen_block)
            eout_block = eout_block_dict[task]['xs']
            is_reset = False  # detect the first boundary in the same block

            # CTC-based VAD
            if streaming.is_ctc_vad:
                if self.ctc_weight_sub1 > 0:
                    ctc_probs_block = self.dec_fwd_sub1.ctc.probs(
                        eout_block_dict['ys_sub1']['xs'])
                    # TODO: consider subsampling
                else:
                    ctc_probs_block = self.dec_fwd.ctc.probs(eout_block)
                is_reset = streaming.ctc_vad(ctc_probs_block, stdout=stdout)

            # Truncate the most right frames
            if is_reset and not is_last_block and streaming.bd_offset >= 0:
                eout_block = eout_block[:, :streaming.bd_offset]
            streaming.cache_eout(eout_block)

            # Block-synchronous decoding.
            # The truncation above can leave an empty block. Skip decoding
            # then: the RNNDecoder loop would not run and `end_hyps`/`hyps`
            # would be unbound, stale or None.
            if eout_block.size(1) > 0:
                if ctc_only:
                    end_hyps, hyps = self.dec_fwd.ctc.beam_search_block_sync(
                        eout_block, params, helper, idx2token, hyps, lm)
                elif isinstance(self.dec_fwd, RNNT):
                    end_hyps, hyps = self.dec_fwd.beam_search_block_sync(
                        eout_block, params, helper, idx2token, hyps, lm)
                elif isinstance(self.dec_fwd, RNNDecoder):
                    for i in range(math.ceil(eout_block.size(1) / block_size)):
                        eout_block_i = eout_block[:, i * block_size:(i + 1) *
                                                  block_size]
                        end_hyps, hyps, _ = self.dec_fwd.beam_search_block_sync(
                            eout_block_i, params, helper, idx2token, hyps, lm)
                elif isinstance(self.dec_fwd, TransformerDecoder):
                    raise NotImplementedError
                else:
                    raise NotImplementedError(self.dec_fwd)

                merged_hyps = sorted(end_hyps + hyps,
                                     key=lambda x: x['score'],
                                     reverse=True)
                if len(merged_hyps) > 0:
                    # drop the first token of the hypothesis
                    best_hyp_id_prefix = np.array(merged_hyps[0]['hyp'][1:])

                    if len(hyps) == 0 or (len(best_hyp_id_prefix) > 0 and
                                          best_hyp_id_prefix[-1] == self.eos):
                        # reset beam if <eos> is generated from the best hypothesis
                        best_hyp_id_prefix = best_hyp_id_prefix[:
                                                                -1]  # exclude <eos>
                        # Segmentation strategy 2:
                        # If <eos> is emitted from the decoder (not CTC),
                        # the current block is segmented.
                        if not is_reset:
                            streaming._bd_offset = eout_block.size(
                                1) - 1  # TODO: fix later
                            is_reset = True

                    if len(best_hyp_id_prefix) > 0:
                        n_frames = self.dec_fwd.ctc.n_frames if ctc_only else self.dec_fwd.n_frames
                        print(
                            '\rStreaming (T:%d [10ms], offset:%d [10ms], blank:%d [10ms]): %s'
                            % (streaming.offset + eout_block.size(1) * factor,
                               n_frames * factor, streaming.n_blanks * factor,
                               idx2token(best_hyp_id_prefix)))

            if is_reset:

                # pick up the best hyp from ended and active hypotheses
                if len(best_hyp_id_prefix) > 0:
                    best_hyp_id_stream.extend(best_hyp_id_prefix)

                # reset
                streaming.reset(stdout=stdout)
                hyps = None
                # The prefix has been committed; clear it so that it cannot be
                # appended a second time if the next segment yields nothing.
                best_hyp_id_prefix = []

            streaming.next_block()
            if is_last_block:
                break
            # next block will start from the frame next to the boundary
            streaming.backoff(x_block, self.dec_fwd, stdout=stdout)

        # pick up the best hyp (the last segment was not closed by a reset)
        if not is_reset and len(best_hyp_id_prefix) > 0:
            best_hyp_id_stream.extend(best_hyp_id_prefix)

        if len(best_hyp_id_stream) > 0:
            return [[np.stack(best_hyp_id_stream, axis=0)]], [None]
        else:
            return [[[]]], [None]
예제 #4
0
    def decode_streaming(self,
                         xs,
                         params,
                         idx2token,
                         exclude_eos=False,
                         speaker=None,
                         task='ys'):
        """Simulate streaming encoding+decoding. Both encoding and decoding are performed in the online mode.

        The utterance is encoded block by block and each non-empty block is
        decoded block-synchronously. A reset point is declared either by
        `streaming.ctc_reset_point_detection()` or when the best hypothesis
        ends with <eos> (unless `streaming.safeguard_reset` is set). At each
        reset point the best hypothesis is appended to the session output and
        the beam is restarted.

        Args:
            xs (list): input features; must contain exactly one utterance
            params (dict): hyper-parameters for decoding. This method reads
                `recog_block_sync` (must be truthy), `recog_block_sync_size`,
                `recog_cache_embedding`, `recog_ctc_weight`,
                `recog_beam_width`, `recog_lm_weight` and
                `recog_lm_second_weight`.
            idx2token (callable): applied to the token-id array when the
                partial hypothesis is printed; also forwarded to `Streaming`
                and to the decoder
            exclude_eos (bool): not used by this method
            speaker: forwarded to `RNNDecoder.beam_search_block_sync()`
            task (str): task to evaluate (only 'ys' is accepted)
        Returns:
            tuple: `([[hyp]], [None])` where `hyp` is the array of token ids
                accumulated over the session, or `([[[]]], [None])` when
                nothing was recognized
        Raises:
            NotImplementedError: if the forward decoder is a
                `TransformerDecoder` or an unsupported type

        """
        self.eval()
        block_size = params.get('recog_block_sync_size')  # before subsampling
        cache_emb = params.get('recog_cache_embedding')
        ctc_weight = params.get('recog_ctc_weight')
        backoff = True

        assert task == 'ys'
        assert self.input_type == 'speech'
        assert self.ctc_weight > 0
        assert self.fwd_weight > 0 or self.ctc_weight == 1.0
        assert len(xs) == 1  # batch size
        assert params.get('recog_block_sync')
        # assert params.get('recog_length_norm')

        streaming = Streaming(xs[0], params, self.enc, idx2token)
        factor = self.enc.subsampling_factor
        block_size //= factor
        assert block_size >= 1, "block_size is too small."
        is_transformer_enc = 'former' in self.enc.enc_type

        hyps = None
        hyps_nobd = []
        # Initialized up front so that an input whose blocks are all empty
        # (no decoding step ever runs) does not raise NameError at the end.
        best_hyp_id_prefix = []
        best_hyp_id_session = []
        is_reset = False

        helper = BeamSearch(params.get('recog_beam_width'), self.eos,
                            params.get('recog_ctc_weight'),
                            params.get('recog_lm_weight'), self.device)

        lm = getattr(self, 'lm_fwd', None)
        lm_second = getattr(self, 'lm_second', None)
        lm = helper.verify_lm_eval_mode(lm, params.get('recog_lm_weight'),
                                        cache_emb)
        if lm is not None:
            assert isinstance(lm, RNNLM)
        lm_second = helper.verify_lm_eval_mode(
            lm_second, params.get('recog_lm_second_weight'), cache_emb)

        # cache token embeddings
        if cache_emb and self.fwd_weight > 0:
            self.dec_fwd.cache_embedding(self.device)

        # Decode with the CTC branch alone when either weight equals one
        ctc_only = ctc_weight == 1 or self.ctc_weight == 1

        self.enc.reset_cache()
        # NOTE: never set to a non-None value inside this method, so the
        # concatenation branch below is currently inert.
        eout_block_tail = None
        x_block_prev, xlen_block_prev = None, None
        while True:
            # Encode input features block by block
            x_block, is_last_block, cnn_lookback, cnn_lookahead, xlen_block = streaming.extract_feat(
            )
            if not is_transformer_enc and is_reset:
                # `is_reset` still holds the decision of the previous block
                self.enc.reset_cache()
                if backoff:
                    # Re-encode the previous block after the cache reset; the
                    # output is discarded.
                    self.encode([x_block_prev],
                                'all',
                                streaming=True,
                                cnn_lookback=cnn_lookback,
                                cnn_lookahead=cnn_lookahead,
                                xlen_block=xlen_block_prev)
            x_block_prev = x_block
            xlen_block_prev = xlen_block
            eout_block_dict = self.encode([x_block],
                                          'all',
                                          streaming=True,
                                          cnn_lookback=cnn_lookback,
                                          cnn_lookahead=cnn_lookahead,
                                          xlen_block=xlen_block)
            eout_block = eout_block_dict[task]['xs']
            if eout_block_tail is not None:
                eout_block = torch.cat([eout_block_tail, eout_block], dim=1)
                eout_block_tail = None

            if eout_block.size(1) > 0:
                streaming.cache_eout(eout_block)

                # Block-synchronous decoding
                if ctc_only:
                    end_hyps, hyps = self.dec_fwd.ctc.beam_search_block_sync(
                        eout_block, params, helper, idx2token, hyps, lm)
                elif isinstance(self.dec_fwd, RNNT):
                    end_hyps, hyps = self.dec_fwd.beam_search_block_sync(
                        eout_block, params, helper, idx2token, hyps, lm)
                elif isinstance(self.dec_fwd, RNNDecoder):
                    for i in range(math.ceil(eout_block.size(1) / block_size)):
                        eout_block_i = eout_block[:, i * block_size:(i + 1) *
                                                  block_size]
                        end_hyps, hyps, hyps_nobd = self.dec_fwd.beam_search_block_sync(
                            eout_block_i,
                            params,
                            helper,
                            idx2token,
                            hyps,
                            hyps_nobd,
                            lm,
                            speaker=speaker)
                elif isinstance(self.dec_fwd, TransformerDecoder):
                    raise NotImplementedError
                else:
                    raise NotImplementedError(self.dec_fwd)

                # CTC-based reset point detection
                is_reset = False
                if streaming.enable_ctc_reset_point_detection:
                    if self.ctc_weight_sub1 > 0:
                        ctc_probs_block = self.dec_fwd_sub1.ctc.probs(
                            eout_block_dict['ys_sub1']['xs'])
                        # TODO: consider subsampling
                    else:
                        ctc_probs_block = self.dec_fwd.ctc.probs(eout_block)
                    is_reset = streaming.ctc_reset_point_detection(
                        ctc_probs_block)

                merged_hyps = sorted(end_hyps + hyps + hyps_nobd,
                                     key=lambda x: x['score'],
                                     reverse=True)
                if len(merged_hyps) > 0:
                    # drop the first token of the hypothesis
                    best_hyp_id_prefix = np.array(merged_hyps[0]['hyp'][1:])
                    # copy for display only: <eos> is not stripped from it
                    best_hyp_id_prefix_viz = np.array(
                        merged_hyps[0]['hyp'][1:])

                    if (len(best_hyp_id_prefix) > 0
                            and best_hyp_id_prefix[-1] == self.eos):
                        # reset beam if <eos> is generated from the best hypothesis
                        best_hyp_id_prefix = best_hyp_id_prefix[:
                                                                -1]  # exclude <eos>
                        # Condition 2:
                        # If <eos> is emitted from the decoder (not CTC),
                        # the current block is segmented.
                        if (not is_reset) and (not streaming.safeguard_reset):
                            is_reset = True

                    if len(best_hyp_id_prefix_viz) > 0:
                        n_frames = self.dec_fwd.ctc.n_frames if ctc_only else self.dec_fwd.n_frames
                        print(
                            '\rStreaming (T:%.3f [sec], offset:%d [frame], blank:%d [frame]): %s'
                            %
                            ((streaming.offset + eout_block.size(1) * factor) /
                             100, n_frames * factor, streaming.n_blanks *
                             factor, idx2token(best_hyp_id_prefix_viz)))

            if is_reset:

                # pick up the best hyp from ended and active hypotheses
                if len(best_hyp_id_prefix) > 0:
                    best_hyp_id_session.extend(best_hyp_id_prefix)

                # reset
                streaming.reset()
                hyps = None
                hyps_nobd = []
                # `is_reset` keeps its value across an empty block (it is only
                # recomputed when decoding runs), so this branch can run again
                # without new hypotheses. Clear the committed prefix so that
                # it is not appended twice.
                best_hyp_id_prefix = []

            streaming.next_block()
            if is_last_block:
                break

        # pick up the best hyp (the last segment was not closed by a reset)
        if not is_reset and len(best_hyp_id_prefix) > 0:
            best_hyp_id_session.extend(best_hyp_id_prefix)

        if len(best_hyp_id_session) > 0:
            return [[np.stack(best_hyp_id_session, axis=0)]], [None]
        else:
            return [[[]]], [None]