Ejemplo n.º 1
0
    def encode_streaming(self, xs, params, task='ys'):
        """Run the encoder over a single utterance one block at a time.

        Only the encoder is driven in the streaming fashion here; the
        accumulated encoder outputs are returned as a whole so that decoding
        can be performed afterwards in the offline mode.

        Args:
            xs (FloatTensor): `[B, T, idim]`, where B must be 1
            params (dict): hyper-parameters for decoding (passed to `Streaming`)
            task (str): task to evaluate (only 'ys' is accepted)
        Returns:
            eout: encoder outputs accumulated over all blocks, as returned by
                `Streaming.pop_eouts()` (second axis is the frame axis)
            elens (IntTensor): `[B]`, number of frames in `eout`

        """
        assert task == 'ys'
        assert self.input_type == 'speech'
        assert self.fwd_weight > 0
        assert len(xs) == 1  # batch size
        feeder = Streaming(xs[0], params, self.enc)

        self.enc.reset_cache()
        reached_end = False
        while not reached_end:
            # Pull the next block of input features and encode it
            (feat_block, reached_end, lookback, lookahead,
             feat_len) = feeder.extract_feat()
            block_outputs = self.encode([feat_block],
                                        'all',
                                        streaming=True,
                                        cnn_lookback=lookback,
                                        cnn_lookahead=lookahead,
                                        xlen_block=feat_len)
            # Stash the encoder outputs of this block, then advance
            feeder.cache_eout(block_outputs[task]['xs'])
            feeder.next_block()

        eout = feeder.pop_eouts()
        n_enc_frames = eout.size(1)
        elens = torch.IntTensor([n_enc_frames])

        return eout, elens
Ejemplo n.º 2
0
    def decode_streaming(self,
                         xs,
                         params,
                         idx2token,
                         exclude_eos=False,
                         speaker=None,
                         task='ys'):
        """Simulate streaming encoding+decoding.

        Both encoding and decoding are performed in the online mode: the input
        is encoded block by block, every non-empty block of encoder outputs is
        decoded with a block-synchronous beam search, and whenever a reset
        point is detected the best hypothesis so far is committed to the
        session-level result and the beam is restarted.

        Args:
            xs (list): mini-batch holding exactly one utterance; `xs[0]` is
                handed to `Streaming`
            params (dict): hyper-parameters for decoding. Keys read here:
                'recog_block_sync_size', 'recog_cache_embedding',
                'recog_ctc_weight', 'recog_block_sync', 'recog_beam_width',
                'recog_lm_weight', 'recog_lm_second_weight'
            idx2token (callable): maps token indices to a printable form; used
                for the progress line and forwarded to the beam search
            exclude_eos (bool): not referenced in this method
            speaker: forwarded to the RNN decoder's block-synchronous search
            task (str): task to evaluate (only 'ys' is accepted)
        Returns:
            nbest_hyps (list): `[[ndarray]]` holding the token indices
                committed over the whole session, or `[[[]]]` when nothing
                was decoded
            aws (list): always `[None]`
        Raises:
            AssertionError: if the model/params do not satisfy the streaming
                decoding prerequisites checked below
            NotImplementedError: if the forward decoder type has no
                block-synchronous search implemented here

        """
        self.eval()
        block_size = params.get('recog_block_sync_size')  # before subsampling
        cache_emb = params.get('recog_cache_embedding')
        ctc_weight = params.get('recog_ctc_weight')
        lm_weight = params.get('recog_lm_weight')
        backoff = True  # re-encode the previous block after an encoder reset

        assert task == 'ys'
        assert self.input_type == 'speech'
        assert self.ctc_weight > 0
        assert self.fwd_weight > 0 or self.ctc_weight == 1.0
        assert len(xs) == 1  # batch size
        assert params.get('recog_block_sync')

        streaming = Streaming(xs[0], params, self.enc, idx2token)
        factor = self.enc.subsampling_factor
        block_size //= factor
        assert block_size >= 1, "block_size is too small."
        is_transformer_enc = 'former' in self.enc.enc_type
        # Decoding relies on CTC only (either requested or by model design)
        ctc_only = ctc_weight == 1 or self.ctc_weight == 1

        hyps = None
        hyps_nobd = []
        best_hyp_id_session = []
        # Best hypothesis not yet committed to the session. It is bound up
        # front so the reset branch and the final flush never read an
        # undefined name when no block has produced a hypothesis yet.
        best_hyp_id_prefix = []
        is_reset = False

        helper = BeamSearch(params.get('recog_beam_width'), self.eos,
                            ctc_weight, lm_weight, self.device)

        lm = getattr(self, 'lm_fwd', None)
        lm_second = getattr(self, 'lm_second', None)
        lm = helper.verify_lm_eval_mode(lm, lm_weight, cache_emb)
        if lm is not None:
            assert isinstance(lm, RNNLM)
        # The result is not used below; the call is kept for whatever checks
        # `verify_lm_eval_mode` performs on the second LM.
        lm_second = helper.verify_lm_eval_mode(
            lm_second, params.get('recog_lm_second_weight'), cache_emb)

        # cache token embeddings
        if cache_emb and self.fwd_weight > 0:
            self.dec_fwd.cache_embedding(self.device)

        self.enc.reset_cache()
        x_block_prev, xlen_block_prev = None, None
        while True:
            # Encode input features block by block
            (x_block, is_last_block, cnn_lookback, cnn_lookahead,
             xlen_block) = streaming.extract_feat()
            if not is_transformer_enc and is_reset:
                # A reset was triggered on an earlier block: clear the
                # encoder cache and, with back-off, re-encode the previous
                # block before encoding the current one.
                self.enc.reset_cache()
                if backoff:
                    self.encode([x_block_prev],
                                'all',
                                streaming=True,
                                cnn_lookback=cnn_lookback,
                                cnn_lookahead=cnn_lookahead,
                                xlen_block=xlen_block_prev)
            x_block_prev = x_block
            xlen_block_prev = xlen_block
            eout_block_dict = self.encode([x_block],
                                          'all',
                                          streaming=True,
                                          cnn_lookback=cnn_lookback,
                                          cnn_lookahead=cnn_lookahead,
                                          xlen_block=xlen_block)
            eout_block = eout_block_dict[task]['xs']

            # NOTE: `is_reset` is only re-evaluated for non-empty blocks, so
            # an empty block inherits the previous block's value.
            if eout_block.size(1) > 0:
                streaming.cache_eout(eout_block)

                # Block-synchronous decoding
                if ctc_only:
                    end_hyps, hyps = self.dec_fwd.ctc.beam_search_block_sync(
                        eout_block, params, helper, idx2token, hyps, lm)
                elif isinstance(self.dec_fwd, RNNT):
                    end_hyps, hyps = self.dec_fwd.beam_search_block_sync(
                        eout_block, params, helper, idx2token, hyps, lm)
                elif isinstance(self.dec_fwd, RNNDecoder):
                    # Feed the decoder sub-blocks of `block_size` frames
                    n_sub_blocks = math.ceil(eout_block.size(1) / block_size)
                    for i in range(n_sub_blocks):
                        eout_block_i = eout_block[:, i * block_size:(i + 1) *
                                                  block_size]
                        end_hyps, hyps, hyps_nobd = self.dec_fwd.beam_search_block_sync(
                            eout_block_i,
                            params,
                            helper,
                            idx2token,
                            hyps,
                            hyps_nobd,
                            lm,
                            speaker=speaker)
                elif isinstance(self.dec_fwd, TransformerDecoder):
                    raise NotImplementedError
                else:
                    raise NotImplementedError(self.dec_fwd)

                # CTC-based reset point detection
                is_reset = False
                if streaming.enable_ctc_reset_point_detection:
                    if self.ctc_weight_sub1 > 0:
                        ctc_probs_block = self.dec_fwd_sub1.ctc.probs(
                            eout_block_dict['ys_sub1']['xs'])
                        # TODO: consider subsampling
                    else:
                        ctc_probs_block = self.dec_fwd.ctc.probs(eout_block)
                    is_reset = streaming.ctc_reset_point_detection(
                        ctc_probs_block)

                merged_hyps = sorted(end_hyps + hyps + hyps_nobd,
                                     key=lambda x: x['score'],
                                     reverse=True)
                if len(merged_hyps) > 0:
                    # Drop the leading start-of-sequence entry
                    best_hyp_id_prefix = np.array(merged_hyps[0]['hyp'][1:])
                    # The progress line shows the hypothesis including <eos>
                    best_hyp_id_prefix_viz = best_hyp_id_prefix

                    if (len(best_hyp_id_prefix) > 0
                            and best_hyp_id_prefix[-1] == self.eos):
                        # reset beam if <eos> is generated from the best hypothesis
                        best_hyp_id_prefix = best_hyp_id_prefix[:-1]  # exclude <eos>
                        # Condition 2:
                        # If <eos> is emitted from the decoder (not CTC),
                        # the current block is segmented.
                        if (not is_reset) and (not streaming.safeguard_reset):
                            is_reset = True

                    if len(best_hyp_id_prefix_viz) > 0:
                        n_frames = self.dec_fwd.ctc.n_frames if ctc_only else self.dec_fwd.n_frames
                        print(
                            '\rStreaming (T:%.3f [sec], offset:%d [frame], blank:%d [frame]): %s'
                            %
                            ((streaming.offset + eout_block.size(1) * factor) /
                             100, n_frames * factor, streaming.n_blanks *
                             factor, idx2token(best_hyp_id_prefix_viz)))

            if is_reset:

                # pick up the best hyp from ended and active hypotheses
                if len(best_hyp_id_prefix) > 0:
                    best_hyp_id_session.extend(best_hyp_id_prefix)
                # The prefix is now committed; clear it so a later block that
                # yields no new hypothesis cannot commit the same tokens again.
                best_hyp_id_prefix = []

                # reset
                streaming.reset()
                hyps = None
                hyps_nobd = []

            streaming.next_block()
            if is_last_block:
                break

        # pick up the best hyp
        if not is_reset and len(best_hyp_id_prefix) > 0:
            best_hyp_id_session.extend(best_hyp_id_prefix)

        if len(best_hyp_id_session) > 0:
            return [[np.stack(best_hyp_id_session, axis=0)]], [None]
        else:
            return [[[]]], [None]