Exemple #1
0
    def encode_streaming(self, xs, params, task='ys'):
        """Encode a single utterance block by block, emulating a streaming encoder.

        Only the encoder is run here: every block produced by `Streaming` is
        passed through `self.encode(..., streaming=True)` and its output is
        cached; the cached outputs are returned in one piece at the end
        (decoding is meant to be done afterwards, in the offline mode).

        Args:
            xs (FloatTensor): `[B, T, idim]`; the batch must hold exactly one utterance
            params (dict): hyper-parameters for decoding (forwarded to `Streaming`)
            task (str): task to evaluate; only 'ys' is accepted
        Returns:
            eout (FloatTensor): cached encoder outputs as returned by
                `streaming.pop_eouts()`; dimension 1 is treated as the length
            elens (IntTensor): `[1]`, holding `eout.size(1)`

        """
        assert task == 'ys'
        assert self.input_type == 'speech'
        assert self.fwd_weight > 0
        assert len(xs) == 1  # batch size
        stream = Streaming(xs[0], params, self.enc)

        # Start from a clean encoder cache before feeding the first block
        self.enc.reset_cache()
        reached_end = False
        while not reached_end:
            # Fetch the next feature block together with its CNN context info
            block_info = stream.extract_feature()
            feats, reached_end, lookback, lookahead, feat_len = block_info

            enc_outputs = self.encode([feats],
                                      'all',
                                      streaming=True,
                                      cnn_lookback=lookback,
                                      cnn_lookahead=lookahead,
                                      xlen_block=feat_len)
            stream.cache_eout(enc_outputs[task]['xs'])
            stream.next_block()

        # Collect everything that was cached during the loop
        eout = stream.pop_eouts()
        return eout, torch.IntTensor([eout.size(1)])
Exemple #2
0
    def decode_streaming(self,
                         xs,
                         params,
                         idx2token,
                         exclude_eos=False,
                         task='ys'):
        """Simulate streaming decoding. Both encoding and decoding are performed in the online mode.

        The utterance is encoded block by block. A segment boundary ("reset")
        is declared either by `streaming.ctc_vad` (when `streaming.is_ctc_vad`
        is set) or, in block-synchronous mode, when the best hypothesis ends
        with <eos>. At every boundary the result for the finished segment is
        appended to the output stream and all per-segment state is cleared.

        Two decoding modes are selected by `params['recog_block_sync']`:
        - truthy: block-synchronous beam search on every encoded block
          (`self.dec_fwd.beam_search_block_sync`, `RNNDecoder` only)
        - falsy: offline beam search (`self.dec_fwd.beam_search`) over the
          cached encoder outputs of each segment, plus one final pass over
          whatever is still cached after the last block

        Args:
            xs (list): batch of input features; must hold exactly one utterance
            params (dict): hyper-parameters for decoding. Keys read here:
                'recog_block_sync', 'recog_block_sync_size' and
                'recog_ctc_weight'. A deep copy with 'recog_max_len_ratio'
                forced to 1.0 is used for the offline beam search.
            idx2token (callable): maps a sequence of token ids to text
                (used for the progress print and forwarded to the decoder)
            exclude_eos (bool): forwarded to `self.dec_fwd.beam_search`
            task (str): task to evaluate; only 'ys' is accepted
        Returns:
            nbest_hyps_id (list): `[[ids]]`, where `ids` is the np.ndarray made
                by stacking all decoded token ids, or `[[[]]]` if nothing was
                decoded
            aws (list): always `[None]`
        Raises:
            NotImplementedError: if `self.dec_fwd` is an `RNNT` or a
                `TransformerDecoder`

        """
        assert task == 'ys'
        assert self.input_type == 'speech'
        assert self.ctc_weight > 0
        assert self.fwd_weight > 0
        assert len(xs) == 1  # batch size
        # assert params['recog_length_norm']
        global_params = copy.deepcopy(params)
        global_params['recog_max_len_ratio'] = 1.0
        block_sync = params['recog_block_sync']
        block_size = params['recog_block_sync_size']  # before subsampling

        streaming = Streaming(xs[0], params, self.enc)
        factor = self.enc.subsampling_factor
        block_size //= factor

        # Per-segment decoding state. `end_hyps` and `best_hyp_id_prefix` are
        # bound up front so that an empty encoder block (e.g. truncation at
        # offset 0) or an empty hypothesis list cannot leave them undefined.
        hyps = None
        end_hyps = []
        best_hyp_id_prefix = []
        best_hyp_id_stream = []
        is_reset = True  # for the first block

        stdout = False

        self.eval()
        lm = getattr(self, 'lm_fwd', None)
        lm_second = getattr(self, 'lm_second', None)
        # NOTE(review): gradient tracking is left enabled here (the guard below
        # is commented out); confirm whether the caller disables autograd.
        # with torch.no_grad():
        while True:
            # Encode input features block by block
            x_block, is_last_block, cnn_lookback, cnn_lookahead, xlen_block = streaming.extract_feature(
            )
            if is_reset:
                # A new segment starts: drop the encoder's streaming cache
                self.enc.reset_cache()
            eout_block_dict = self.encode([x_block],
                                          'all',
                                          streaming=True,
                                          cnn_lookback=cnn_lookback,
                                          cnn_lookahead=cnn_lookahead,
                                          xlen_block=xlen_block)
            eout_block = eout_block_dict[task]['xs']
            is_reset = False  # detect the first boundary in the same block

            # CTC-based VAD
            if streaming.is_ctc_vad:
                if self.ctc_weight_sub1 > 0:
                    ctc_probs_block = self.dec_fwd_sub1.ctc_probs(
                        eout_block_dict['ys_sub1']['xs'])
                    # TODO: consider subsampling
                else:
                    ctc_probs_block = self.dec_fwd.ctc_probs(eout_block)
                is_reset = streaming.ctc_vad(ctc_probs_block, stdout=stdout)

            # Truncate the rightmost frames (those after the detected boundary)
            if is_reset and not is_last_block and streaming.bd_offset >= 0:
                eout_block = eout_block[:, :streaming.bd_offset]
            streaming.cache_eout(eout_block)

            # Block-synchronous attention decoding
            if isinstance(self.dec_fwd, RNNT):
                raise NotImplementedError
            elif isinstance(self.dec_fwd, RNNDecoder) and block_sync:
                for i in range(math.ceil(eout_block.size(1) / block_size)):
                    eout_block_i = eout_block[:, i * block_size:(i + 1) *
                                              block_size]
                    end_hyps, hyps, _ = self.dec_fwd.beam_search_block_sync(
                        eout_block_i,
                        params,
                        idx2token,
                        hyps,
                        lm,
                        state_carry_over=False)
                # `hyps` is still None if no sub-block was decoded since the
                # last reset (empty `eout_block`), hence the `or []`.
                merged_hyps = sorted(end_hyps + (hyps or []),
                                     key=lambda x: x['score'],
                                     reverse=True)
                if len(merged_hyps) > 0:
                    # [1:] skips the first token of the hypothesis
                    best_hyp_id_prefix = np.array(merged_hyps[0]['hyp'][1:])
                    if len(best_hyp_id_prefix
                           ) > 0 and best_hyp_id_prefix[-1] == self.eos:
                        # reset beam if <eos> is generated from the best hypothesis
                        best_hyp_id_prefix = best_hyp_id_prefix[:
                                                                -1]  # exclude <eos>
                        # Segmentation strategy 2:
                        # If <eos> is emitted from the decoder (not CTC),
                        # the current block is segmented.
                        if not is_reset:
                            streaming._bd_offset = eout_block.size(1) - 1
                            # TODO: fix later
                            is_reset = True
                    if len(best_hyp_id_prefix) > 0:
                        print(
                            '\rStreaming (T:%d [10ms], offset:%d [10ms], blank:%d [10ms]): %s'
                            % (streaming.offset + eout_block.size(1) * factor,
                               self.dec_fwd.n_frames * factor,
                               streaming.n_blanks * factor,
                               idx2token(best_hyp_id_prefix)))
            elif isinstance(self.dec_fwd, TransformerDecoder):
                best_hyp_id_prefix = []
                raise NotImplementedError

            if is_reset:
                # Global decoding over the segmented region
                if not block_sync:
                    eout = streaming.pop_eouts()
                    elens = torch.IntTensor([eout.size(1)])
                    ctc_log_probs = None
                    if params['recog_ctc_weight'] > 0:
                        ctc_log_probs = torch.log(self.dec_fwd.ctc_probs(eout))
                    nbest_hyps_id_offline = self.dec_fwd.beam_search(
                        eout,
                        elens,
                        global_params,
                        idx2token,
                        lm,
                        lm_second,
                        ctc_log_probs=ctc_log_probs,
                        exclude_eos=exclude_eos)[0]

                # pick up the best hyp from ended and active hypotheses
                if block_sync:
                    if len(best_hyp_id_prefix) > 0:
                        best_hyp_id_stream.extend(best_hyp_id_prefix)
                else:
                    if len(nbest_hyps_id_offline[0][0]) > 0:
                        best_hyp_id_stream.extend(nbest_hyps_id_offline[0][0])

                # reset
                streaming.reset(stdout=stdout)
                hyps = None
                # Also forget the finished segment's results so they cannot be
                # merged into, or re-emitted for, the next segment.
                end_hyps = []
                best_hyp_id_prefix = []

            streaming.next_block()
            if is_last_block:
                break
            # next block will start from the frame next to the boundary
            streaming.backoff(x_block, self.dec_fwd, stdout=stdout)

        # Global decoding for tail blocks
        if not block_sync and streaming.n_cache_block > 0:
            eout = streaming.pop_eouts()
            elens = torch.IntTensor([eout.size(1)])
            # Score the tail exactly like every other segment: pass CTC
            # log-probabilities whenever joint CTC decoding is requested.
            ctc_log_probs = None
            if params['recog_ctc_weight'] > 0:
                ctc_log_probs = torch.log(self.dec_fwd.ctc_probs(eout))
            nbest_hyps_id_offline = self.dec_fwd.beam_search(
                eout,
                elens,
                global_params,
                idx2token,
                lm,
                lm_second,
                ctc_log_probs=ctc_log_probs,
                exclude_eos=exclude_eos)[0]
            if len(nbest_hyps_id_offline[0][0]) > 0:
                best_hyp_id_stream.extend(nbest_hyps_id_offline[0][0])

        # pick up the best hyp of the last, still open segment
        if not is_reset and block_sync and len(best_hyp_id_prefix) > 0:
            best_hyp_id_stream.extend(best_hyp_id_prefix)

        if len(best_hyp_id_stream) > 0:
            return [[np.stack(best_hyp_id_stream, axis=0)]], [None]
        else:
            return [[[]]], [None]