Example #1
    def decode(self,
               xs,
               params,
               idx2token,
               exclude_eos=False,
               refs_id=None,
               refs=None,
               utt_ids=None,
               speakers=None,
               task='ys',
               ensemble_models=[]):
        """Decoding in the inference stage.

        Args:
            xs (list): A list of length `[B]`, which contains arrays of size `[T, input_dim]`
            params (dict): hyper-parameters for decoding
                beam_width (int): the size of beam
                min_len_ratio (float): minimum hypothesis length as a ratio of the input length
                max_len_ratio (float): maximum hypothesis length as a ratio of the input length
                len_penalty (float): length penalty
                cov_penalty (float): coverage penalty
                cov_threshold (float): threshold for coverage penalty
                lm_weight (float): the weight of RNNLM score
                resolving_unk (bool): not used (kept for compatibility)
                fwd_bwd_attention (bool):
            idx2token (callable): converter from index to token
            exclude_eos (bool): exclude <eos> from best_hyps_id
            refs_id (list): gold token IDs to compute log likelihood
            refs (list): gold transcriptions
            utt_ids (list): utterance id list
            speakers (list): speaker list
            task (str): ys* or ys_sub1* or ys_sub2*
            ensemble_models (list): list of Speech2Text classes
        Returns:
            best_hyps_id (list): A list of length `[B]`, which contains arrays of size `[L]`
            aws (list): A list of length `[B]`, which contains arrays of size `[L, T, n_heads]`

        """
        if task.split('.')[0] == 'ys':
            dir = 'bwd' if self.bwd_weight > 0 and params['recog_bwd_attention'] else 'fwd'
        elif task.split('.')[0] == 'ys_sub1':
            dir = 'fwd_sub1'
        elif task.split('.')[0] == 'ys_sub2':
            dir = 'fwd_sub2'
        else:
            raise ValueError(task)

        self.eval()
        with torch.no_grad():
            # Encode input features
            eout_dict = self.encode(xs, task)

            # CTC
            if (self.fwd_weight == 0 and self.bwd_weight == 0) or (
                    self.ctc_weight > 0 and params['recog_ctc_weight'] == 1):
                lm = getattr(self, 'lm_' + dir, None)
                lm_second = getattr(self, 'lm_second', None)
                lm_second_bwd = None  # TODO

                best_hyps_id = getattr(self, 'dec_' + dir).decode_ctc(
                    eout_dict[task]['xs'], eout_dict[task]['xlens'], params,
                    idx2token, lm, lm_second, lm_second_bwd, 1, refs_id,
                    utt_ids, speakers)
                return best_hyps_id, None

            # Attention
            elif params['recog_beam_width'] == 1 and not params['recog_fwd_bwd_attention']:
                best_hyps_id, aws = getattr(self, 'dec_' + dir).greedy(
                    eout_dict[task]['xs'], eout_dict[task]['xlens'],
                    params['recog_max_len_ratio'], idx2token, exclude_eos,
                    refs_id, utt_ids, speakers)
            else:
                assert params['recog_batch_size'] == 1

                ctc_log_probs = None
                if params['recog_ctc_weight'] > 0:
                    ctc_log_probs = self.dec_fwd.ctc_log_probs(
                        eout_dict[task]['xs'])

                # forward-backward decoding
                if params['recog_fwd_bwd_attention']:
                    lm_fwd = getattr(self, 'lm_fwd', None)
                    lm_bwd = getattr(self, 'lm_bwd', None)

                    # forward decoder
                    nbest_hyps_id_fwd, aws_fwd, scores_fwd = self.dec_fwd.beam_search(
                        eout_dict[task]['xs'], eout_dict[task]['xlens'],
                        params, idx2token, lm_fwd, None, lm_bwd, ctc_log_probs,
                        params['recog_beam_width'], False, refs_id, utt_ids,
                        speakers)

                    # backward decoder
                    nbest_hyps_id_bwd, aws_bwd, scores_bwd, _ = self.dec_bwd.beam_search(
                        eout_dict[task]['xs'], eout_dict[task]['xlens'],
                        params, idx2token, lm_bwd, None, lm_fwd, ctc_log_probs,
                        params['recog_beam_width'], False, refs_id, utt_ids,
                        speakers)

                    # forward-backward attention
                    best_hyps_id = fwd_bwd_attention(
                        nbest_hyps_id_fwd, aws_fwd, scores_fwd,
                        nbest_hyps_id_bwd, aws_bwd, scores_bwd, self.eos,
                        params['recog_gnmt_decoding'],
                        params['recog_length_penalty'], idx2token, refs_id)
                    aws = None
                else:
                    # ensemble
                    ensmbl_eouts, ensmbl_elens, ensmbl_decs = [], [], []
                    if len(ensemble_models) > 0:
                        for i_e, model in enumerate(ensemble_models):
                            enc_outs_e = model.encode(xs, task)
                            ensmbl_eouts += [enc_outs_e[task]['xs']]
                            ensmbl_elens += [enc_outs_e[task]['xlens']]
                            ensmbl_decs += [getattr(model, 'dec_' + dir)]
                            # NOTE: only support for the main task now

                    lm = getattr(self, 'lm_' + dir, None)
                    lm_second = getattr(self, 'lm_second', None)
                    lm_bwd = getattr(self, 'lm_bwd' if dir == 'fwd' else 'lm_fwd', None)

                    nbest_hyps_id, aws, scores = getattr(
                        self, 'dec_' + dir).beam_search(
                            eout_dict[task]['xs'], eout_dict[task]['xlens'],
                            params, idx2token, lm, lm_second, lm_bwd,
                            ctc_log_probs, 1, exclude_eos, refs_id, utt_ids,
                            speakers, ensmbl_eouts, ensmbl_elens, ensmbl_decs)
                    best_hyps_id = [hyp[0] for hyp in nbest_hyps_id]

            return best_hyps_id, aws
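
For orientation, here is a minimal hypothetical driver for the method above. `model` (a trained Speech2Text instance) and `idx2token` are assumed rather than constructed, and only the recog_* keys read directly by the snippet are filled in; the decoder's internal beam search may consult further keys.

# Hypothetical usage sketch; `model` and `idx2token` are assumed to exist.
import numpy as np

xs = [np.random.randn(200, 80).astype(np.float32)]  # one utterance: [T, input_dim]

params = {
    'recog_bwd_attention': False,      # use the forward decoder
    'recog_ctc_weight': 0.0,           # skip the pure-CTC branch
    'recog_beam_width': 10,            # > 1 selects beam search
    'recog_fwd_bwd_attention': False,
    'recog_batch_size': 1,             # asserted by the beam-search branch
    'recog_max_len_ratio': 1.0,
}

best_hyps_id, aws = model.decode(xs, params, idx2token, exclude_eos=True)
print(idx2token(best_hyps_id[0]))      # best hypothesis of the first utterance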
Example #2
    def decode(self,
               xs,
               params,
               idx2token,
               exclude_eos=False,
               refs_id=None,
               refs=None,
               utt_ids=None,
               speakers=None,
               task='ys',
               ensemble_models=[],
               trigger_points=None,
               teacher_force=False):
        """Decode in the inference stage.

        Args:
            xs (List): length `[B]`, which contains arrays of size `[T, input_dim]`
            params (dict): hyper-parameters for decoding
            idx2token (callable): converter from index to token
            exclude_eos (bool): exclude <eos> from best_hyps_id
            refs_id (List): gold token IDs to compute log likelihood
            refs (List): gold transcriptions
            utt_ids (List): utterance id list
            speakers (List): speaker list
            task (str): ys* or ys_sub1* or ys_sub2*
            ensemble_models (List): Speech2Text classes
            trigger_points (np.ndarray): `[B, L]`
            teacher_force (bool): conduct teacher-forcing
        Returns:
            nbest_hyps_id (List[List[np.ndarray]]): length `[B]`, which contains a list of length `[n_best]` which contains arrays of size `[L]`
            aws (List[np.ndarray]): length `[B]`, which contains arrays of size `[L, T, n_heads]`

        """
        if task.split('.')[0] == 'ys':
            dir = 'bwd' if self.bwd_weight > 0 and params['recog_bwd_attention'] else 'fwd'
        elif task.split('.')[0] == 'ys_sub1':
            dir = 'fwd_sub1'
        elif task.split('.')[0] == 'ys_sub2':
            dir = 'fwd_sub2'
        else:
            raise ValueError(task)

        if utt_ids is not None:
            if self.utt_id_prev != utt_ids[0]:
                self.reset_session()
            self.utt_id_prev = utt_ids[0]

        self.eval()
        with torch.no_grad():
            # Encode input features
            if params['recog_streaming_encoding']:
                eouts, elens = self.encode_streaming(xs, params, task)
            else:
                eout_dict = self.encode(xs, task)
                eouts = eout_dict[task]['xs']
                elens = eout_dict[task]['xlens']

            # CTC
            if (self.fwd_weight == 0 and self.bwd_weight == 0) or (
                    self.ctc_weight > 0 and params['recog_ctc_weight'] == 1):
                lm = getattr(self, 'lm_' + dir, None)
                lm_second = getattr(self, 'lm_second', None)
                lm_second_bwd = None  # TODO

                if params.get('recog_beam_width') == 1:
                    nbest_hyps_id = getattr(self, 'dec_' + dir).ctc.greedy(
                        eouts, elens)
                else:
                    nbest_hyps_id = getattr(self, 'dec_' + dir).ctc.beam_search(
                        eouts, elens, params, idx2token, lm, lm_second,
                        lm_second_bwd, 1, refs_id, utt_ids, speakers)
                return nbest_hyps_id, None

            # Attention/RNN-T
            elif params['recog_beam_width'] == 1 and not params['recog_fwd_bwd_attention']:
                best_hyps_id, aws = getattr(self, 'dec_' + dir).greedy(
                    eouts, elens, params['recog_max_len_ratio'], idx2token,
                    exclude_eos, refs_id, utt_ids, speakers)
                nbest_hyps_id = [[hyp] for hyp in best_hyps_id]
            else:
                assert params['recog_batch_size'] == 1

                scores_ctc = None
                if params['recog_ctc_weight'] > 0:
                    scores_ctc = self.dec_fwd.ctc.scores(eouts)

                # forward-backward decoding
                if params['recog_fwd_bwd_attention']:
                    lm = getattr(self, 'lm_fwd', None)
                    lm_bwd = getattr(self, 'lm_bwd', None)

                    # forward decoder
                    nbest_hyps_id_fwd, aws_fwd, scores_fwd = self.dec_fwd.beam_search(
                        eouts, elens, params, idx2token, lm, None, lm_bwd,
                        scores_ctc, params['recog_beam_width'], False, refs_id,
                        utt_ids, speakers)

                    # backward decoder
                    nbest_hyps_id_bwd, aws_bwd, scores_bwd, _ = self.dec_bwd.beam_search(
                        eouts, elens, params, idx2token, lm_bwd, None, lm,
                        scores_ctc, params['recog_beam_width'], False, refs_id,
                        utt_ids, speakers)

                    # forward-backward attention
                    best_hyps_id = fwd_bwd_attention(
                        nbest_hyps_id_fwd, aws_fwd, scores_fwd,
                        nbest_hyps_id_bwd, aws_bwd, scores_bwd, self.eos,
                        params['recog_gnmt_decoding'],
                        params['recog_length_penalty'], idx2token, refs_id)
                    nbest_hyps_id = [[hyp] for hyp in best_hyps_id]
                    aws = None
                else:
                    # ensemble
                    ensmbl_eouts, ensmbl_elens, ensmbl_decs = [], [], []
                    if len(ensemble_models) > 0:
                        for i_e, model in enumerate(ensemble_models):
                            enc_outs_e = model.encode(xs, task)
                            ensmbl_eouts += [enc_outs_e[task]['xs']]
                            ensmbl_elens += [enc_outs_e[task]['xlens']]
                            ensmbl_decs += [getattr(model, 'dec_' + dir)]
                            # NOTE: only support for the main task now

                    lm = getattr(self, 'lm_' + dir, None)
                    lm_second = getattr(self, 'lm_second', None)
                    lm_bwd = getattr(self, 'lm_bwd' if dir == 'fwd' else 'lm_fwd', None)

                    nbest_hyps_id, aws, scores = getattr(
                        self, 'dec_' + dir).beam_search(
                            eouts, elens, params, idx2token, lm, lm_second,
                            lm_bwd, scores_ctc, params['recog_beam_width'],
                            exclude_eos, refs_id, utt_ids, speakers,
                            ensmbl_eouts, ensmbl_elens, ensmbl_decs)

            return nbest_hyps_id, aws
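
A usage note on the variant above: the CTC branch is taken either when the model has no attention/transducer decoder (fwd_weight == 0 and bwd_weight == 0) or when recog_ctc_weight == 1. A hypothetical sketch of forcing CTC-only decoding, with `model`, `idx2token` and `xs` assumed as before:

# Hypothetical sketch: CTC-only decoding (requires a model trained with
# ctc_weight > 0). recog_beam_width == 1 -> ctc.greedy, > 1 -> ctc.beam_search.
params = {
    'recog_ctc_weight': 1,
    'recog_beam_width': 1,
    'recog_bwd_attention': False,
    'recog_streaming_encoding': False,  # use the offline encoder
}
nbest_hyps_id, aws = model.decode(xs, params, idx2token, utt_ids=['utt-0001'])
assert aws is None  # the CTC branch returns no attention weights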
Example #3
    def decode(self,
               xs,
               params,
               idx2token,
               nbest=1,
               exclude_eos=False,
               refs_id=None,
               refs_text=None,
               utt_ids=None,
               speakers=None,
               task='ys',
               ensemble_models=[]):
        """Decoding in the inference stage.

        Args:
            xs (list): A list of length `[B]`, which contains arrays of size `[T, input_dim]`
            params (dict): hyper-parameters for decoding
                beam_width (int): the size of beam
                min_len_ratio (float): minimum hypothesis length as a ratio of the input length
                max_len_ratio (float): maximum hypothesis length as a ratio of the input length
                len_penalty (float): length penalty
                cov_penalty (float): coverage penalty
                cov_threshold (float): threshold for coverage penalty
                lm_weight (float): the weight of RNNLM score
                resolving_unk (bool): not used (kept for compatibility)
                fwd_bwd_attention (bool):
            idx2token (callable): converter from index to token
            nbest (int): the number of hypotheses to return per utterance
            exclude_eos (bool): exclude <eos> from best_hyps_id
            refs_id (list): gold token IDs to compute log likelihood
            refs_text (list): gold transcriptions
            utt_ids (list): utterance id list
            speakers (list): speaker list
            task (str): ys* or ys_sub1* or ys_sub2*
            ensemble_models (list): list of Speech2Text classes
        Returns:
            best_hyps_id (list): A list of length `[B]`, which contains arrays of size `[L]`
            aws (list): A list of length `[B]`, which contains arrays of size `[L, T, n_heads]`
            cache_info (tuple):

        """
        self.eval()
        with torch.no_grad():
            if task.split('.')[0] == 'ys':
                dir = 'bwd' if self.bwd_weight > 0 and params['recog_bwd_attention'] else 'fwd'
            elif task.split('.')[0] == 'ys_sub1':
                dir = 'fwd_sub1'
            elif task.split('.')[0] == 'ys_sub2':
                dir = 'fwd_sub2'
            else:
                raise ValueError(task)

            # Encode input features
            if self.input_type == 'speech' and self.mtl_per_batch and 'bwd' in dir:
                enc_outs = self.encode(xs, task, flip=True)
            else:
                enc_outs = self.encode(xs, task, flip=False)

            #########################
            # CTC
            #########################
            if (self.fwd_weight == 0 and self.bwd_weight == 0) or (
                    self.ctc_weight > 0 and params['recog_ctc_weight'] == 1):
                lm = None
                if params['recog_lm_weight'] > 0 and getattr(self, 'lm_fwd', None) is not None:
                    lm = getattr(self, 'lm_' + dir)

                best_hyps_id = getattr(self, 'dec_' + dir).decode_ctc(
                    enc_outs[task]['xs'], enc_outs[task]['xlens'], params,
                    idx2token, lm, nbest, refs_id, utt_ids, speakers)
                return best_hyps_id, None, (None, None)

            #########################
            # Attention
            #########################
            else:
                cache_info = (None, None)

                if params['recog_beam_width'] == 1 and not params['recog_fwd_bwd_attention']:
                    best_hyps_id, aws = getattr(self, 'dec_' + dir).greedy(
                        enc_outs[task]['xs'], enc_outs[task]['xlens'],
                        params['recog_max_len_ratio'], idx2token, exclude_eos,
                        refs_id, speakers, params['recog_oracle'])
                else:
                    assert params['recog_batch_size'] == 1

                    ctc_log_probs = None
                    if params['recog_ctc_weight'] > 0:
                        ctc_log_probs = self.dec_fwd.ctc_log_probs(
                            enc_outs[task]['xs'])

                    # forward-backward decoding
                    if params['recog_fwd_bwd_attention']:
                        # forward decoder
                        lm_fwd, lm_bwd = None, None
                        if params['recog_lm_weight'] > 0 and getattr(self, 'lm_fwd', None) is not None:
                            lm_fwd = self.lm_fwd
                            if params['recog_reverse_lm_rescoring'] and getattr(self, 'lm_bwd', None) is not None:
                                lm_bwd = self.lm_bwd

                        # ensemble (forward)
                        ensmbl_eouts_fwd, ensmbl_elens_fwd, ensmbl_decs_fwd = [], [], []
                        if len(ensemble_models) > 0:
                            for i_e, model in enumerate(ensemble_models):
                                enc_outs_e_fwd = model.encode(xs, task, flip=False)
                                ensmbl_eouts_fwd += [enc_outs_e_fwd[task]['xs']]
                                ensmbl_elens_fwd += [enc_outs_e_fwd[task]['xlens']]
                                ensmbl_decs_fwd += [model.dec_fwd]
                                # NOTE: only support for the main task now

                        nbest_hyps_id_fwd, aws_fwd, scores_fwd, cache_info = self.dec_fwd.beam_search(
                            enc_outs[task]['xs'], enc_outs[task]['xlens'],
                            params, idx2token, lm_fwd, lm_bwd, ctc_log_probs,
                            params['recog_beam_width'], False, refs_id,
                            utt_ids, speakers, ensmbl_eouts_fwd,
                            ensmbl_elens_fwd, ensmbl_decs_fwd)

                        # backward decoder
                        lm_bwd, lm_fwd = None, None
                        if params['recog_lm_weight'] > 0 and getattr(self, 'lm_bwd', None) is not None:
                            lm_bwd = self.lm_bwd
                            if params['recog_reverse_lm_rescoring'] and getattr(self, 'lm_fwd', None) is not None:
                                lm_fwd = self.lm_fwd

                        # ensemble (backward)
                        ensmbl_eouts_bwd, ensmbl_elens_bwd, ensmbl_decs_bwd = [], [], []
                        if len(ensemble_models) > 0:
                            for i_e, model in enumerate(ensemble_models):
                                if self.input_type == 'speech' and self.mtl_per_batch:
                                    enc_outs_e_bwd = model.encode(xs, task, flip=True)
                                else:
                                    enc_outs_e_bwd = model.encode(xs, task, flip=False)
                                ensmbl_eouts_bwd += [enc_outs_e_bwd[task]['xs']]
                                ensmbl_elens_bwd += [enc_outs_e_bwd[task]['xlens']]
                                ensmbl_decs_bwd += [model.dec_bwd]
                                # NOTE: only support for the main task now
                                # TODO(hirofumi): merge with the forward for the efficiency

                        flip = False
                        if self.input_type == 'speech' and self.mtl_per_batch:
                            flip = True
                            enc_outs_bwd = self.encode(xs, task, flip=True)
                        else:
                            enc_outs_bwd = enc_outs
                        nbest_hyps_id_bwd, aws_bwd, scores_bwd, _ = self.dec_bwd.beam_search(
                            enc_outs_bwd[task]['xs'], enc_outs_bwd[task]['xlens'],
                            params, idx2token, lm_bwd, lm_fwd, ctc_log_probs,
                            params['recog_beam_width'], False, refs_id,
                            utt_ids, speakers, ensmbl_eouts_bwd,
                            ensmbl_elens_bwd, ensmbl_decs_bwd)

                        # forward-backward attention
                        best_hyps_id = fwd_bwd_attention(
                            nbest_hyps_id_fwd, aws_fwd, scores_fwd,
                            nbest_hyps_id_bwd, aws_bwd, scores_bwd, flip,
                            self.eos, params['recog_gnmt_decoding'],
                            params['recog_length_penalty'], idx2token, refs_id)
                        aws = None
                    else:
                        # ensemble
                        ensmbl_eouts, ensmbl_elens, ensmbl_decs = [], [], []
                        if len(ensemble_models) > 0:
                            for i_e, model in enumerate(ensemble_models):
                                if model.input_type == 'speech' and model.mtl_per_batch and 'bwd' in dir:
                                    enc_outs_e = model.encode(xs, task, flip=True)
                                else:
                                    enc_outs_e = model.encode(xs, task, flip=False)
                                ensmbl_eouts += [enc_outs_e[task]['xs']]
                                ensmbl_elens += [enc_outs_e[task]['xlens']]
                                ensmbl_decs += [getattr(model, 'dec_' + dir)]
                                # NOTE: only support for the main task now

                        lm, lm_rev = None, None
                        if params['recog_lm_weight'] > 0 and getattr(self, 'lm_' + dir, None) is not None:
                            lm = getattr(self, 'lm_' + dir)
                            if params['recog_reverse_lm_rescoring']:
                                if dir == 'fwd':
                                    lm_rev = self.lm_bwd
                                else:
                                    raise NotImplementedError

                        nbest_hyps_id, aws, scores, cache_info = getattr(
                            self, 'dec_' + dir).beam_search(
                                enc_outs[task]['xs'], enc_outs[task]['xlens'],
                                params, idx2token, lm, lm_rev, ctc_log_probs,
                                nbest, exclude_eos, refs_id, utt_ids, speakers,
                                ensmbl_eouts, ensmbl_elens, ensmbl_decs)

                        if nbest == 1:
                            best_hyps_id = [hyp[0] for hyp in nbest_hyps_id]
                            aws = [aw[0] for aw in aws] if aws is not None else aws
                        else:
                            return nbest_hyps_id, aws, scores, cache_info
                        # NOTE: nbest >= 2 is used for MWER training only

                return best_hyps_id, aws, cache_info
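
This variant additionally exposes `nbest` and threads a `cache_info` tuple through the return values; per the NOTE above, nbest >= 2 is used for MWER training only. A hypothetical sketch of requesting an n-best list, with `model`, `idx2token`, `xs` and the base `params` assumed as before:

# Hypothetical sketch: n-best decoding with this variant's interface.
params.update({
    'recog_beam_width': 8,
    'recog_batch_size': 1,
    'recog_ctc_weight': 0.0,
    'recog_lm_weight': 0.0,
    'recog_fwd_bwd_attention': False,
})
nbest_hyps_id, aws, scores, cache_info = model.decode(
    xs, params, idx2token, nbest=4, exclude_eos=True)
# nbest >= 2 returns early with four values; nbest == 1 would instead
# return (best_hyps_id, aws, cache_info).
for hyps in nbest_hyps_id:             # up to 4 hypotheses per utterance
    for rank, hyp in enumerate(hyps):
        print(rank, idx2token(hyp))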

Example #4
    def decode(self,
               xs,
               params,
               idx2token,
               exclude_eos=False,
               refs_id=None,
               refs=None,
               utt_ids=None,
               speakers=None,
               task='ys',
               ensemble_models=[],
               trigger_points=None,
               teacher_force=False):
        """Decode in the inference stage.

        Args:
            xs (List): length `[B]`, which contains arrays of size `[T, input_dim]`
            params (dict): hyper-parameters for decoding
            idx2token (callable): converter from index to token
            exclude_eos (bool): exclude <eos> from best_hyps_id
            refs_id (List): gold token IDs to compute log likelihood
            refs (List): gold transcriptions
            utt_ids (List): utterance id list
            speakers (List): speaker list
            task (str): ys* or ys_sub1* or ys_sub2*
            ensemble_models (List): Speech2Text classes
            trigger_points (np.ndarray): `[B, L]`
            teacher_force (bool): conduct teacher-forcing
        Returns:
            nbest_hyps_id (List[List[np.ndarray]]): length `[B]`, which contains a list of length `[n_best]` which contains arrays of size `[L]`
            aws (List[np.ndarray]): length `[B]`, which contains arrays of size `[L, T, n_heads]`

        """
        self.eval()
        if task.split('.')[0] == 'ys':
            dir = 'bwd' if self.bwd_weight > 0 and params['recog_bwd_attention'] else 'fwd'
        elif task.split('.')[0] == 'ys_sub1':
            dir = 'fwd_sub1'
        elif task.split('.')[0] == 'ys_sub2':
            dir = 'fwd_sub2'
        else:
            raise ValueError(task)

        if utt_ids is not None:
            if self.utt_id_prev != utt_ids[0]:
                self.reset_session()
            self.utt_id_prev = utt_ids[0]

        # Encode input features
        if params['recog_streaming_encoding']:
            eouts, elens = self.encode_streaming(xs, params, task)
        else:
            eout_dict = self.encode(xs, task)
            eouts = eout_dict[task]['xs']
            elens = eout_dict[task]['xlens']

        # CTC
        if (self.fwd_weight == 0 and self.bwd_weight == 0) or (
                self.ctc_weight > 0 and params['recog_ctc_weight'] == 1):
            lm = getattr(self, 'lm_' + dir, None)
            lm_second = getattr(self, 'lm_second', None)
            lm_second_bwd = None  # TODO

            if params.get('recog_beam_width') == 1:
                nbest_hyps_id = getattr(self, 'dec_' + dir).ctc.greedy(eouts, elens)
            else:
                nbest_hyps_id = getattr(self, 'dec_' + dir).ctc.beam_search(
                    eouts, elens, params, idx2token, lm, lm_second,
                    lm_second_bwd, 1, refs_id, utt_ids, speakers)
            return nbest_hyps_id, None

        # Attention/RNN-T
        elif params['recog_beam_width'] == 1 and not params['recog_fwd_bwd_attention']:
            best_hyps_id, aws = getattr(self, 'dec_' + dir).greedy(
                eouts, elens, params['recog_max_len_ratio'], idx2token,
                exclude_eos, refs_id, utt_ids, speakers)
            nbest_hyps_id = [[hyp] for hyp in best_hyps_id]
        elif self.is_wfst:  # TODO: config
            nbest_hyps_id = []
            bs = eouts.size(0)

            # decode one utterance at a time
            for b in range(bs):
                encode_out = eouts[b].unsqueeze(0)
                initial_packed_states = (0, )
                inference_one_step = getattr(self, 'dec_' + dir).decode_wfst_onestep
                self.decoder.decode(encode_out, initial_packed_states,
                                    inference_one_step)
                words_prediction_id = self.decoder.get_best_path()
                words_prediction = ''.join(
                    [self.words[int(idx)] for idx in words_prediction_id])
                predictions = [
                    self.vocab_wfst[prediction] for prediction in words_prediction
                ]
                nbest_hyps_id.append([np.array(predictions)])
            aws = None
        else:
            assert params['recog_batch_size'] == 1

            scores_ctc = None
            if params['recog_ctc_weight'] > 0:
                scores_ctc = self.dec_fwd.ctc.scores(eouts)

            # forward-backward decoding
            if params['recog_fwd_bwd_attention']:
                lm = getattr(self, 'lm_fwd', None)
                lm_bwd = getattr(self, 'lm_bwd', None)

                # forward decoder
                nbest_hyps_id_fwd, aws_fwd, scores_fwd = self.dec_fwd.beam_search(
                    eouts, elens, params, idx2token, lm, None, lm_bwd,
                    scores_ctc, params['recog_beam_width'], False, refs_id,
                    utt_ids, speakers)

                # backward decoder
                nbest_hyps_id_bwd, aws_bwd, scores_bwd, _ = self.dec_bwd.beam_search(
                    eouts, elens, params, idx2token, lm_bwd, None, lm,
                    scores_ctc, params['recog_beam_width'], False, refs_id,
                    utt_ids, speakers)

                # forward-backward attention
                best_hyps_id = fwd_bwd_attention(
                    nbest_hyps_id_fwd, aws_fwd, scores_fwd, nbest_hyps_id_bwd,
                    aws_bwd, scores_bwd, self.eos,
                    params['recog_gnmt_decoding'],
                    params['recog_length_penalty'], idx2token, refs_id)
                nbest_hyps_id = [[hyp] for hyp in best_hyps_id]
                aws = None
            else:
                # ensemble
                ensmbl_eouts, ensmbl_elens, ensmbl_decs = [], [], []
                if len(ensemble_models) > 0:
                    for i_e, model in enumerate(ensemble_models):
                        enc_outs_e = model.encode(xs, task)
                        ensmbl_eouts += [enc_outs_e[task]['xs']]
                        ensmbl_elens += [enc_outs_e[task]['xlens']]
                        ensmbl_decs += [getattr(model, 'dec_' + dir)]
                        # NOTE: only support for the main task now

                lm = getattr(self, 'lm_' + dir, None)
                lm_second = getattr(self, 'lm_second', None)
                lm_bwd = getattr(self, 'lm_bwd' if dir == 'fwd' else 'lm_fwd',
                                 None)

                nbest_hyps_id, aws, scores = getattr(
                    self, 'dec_' + dir).beam_search(
                        eouts, elens, params, idx2token, lm, lm_second, lm_bwd,
                        scores_ctc, params['recog_beam_width'], exclude_eos,
                        refs_id, utt_ids, speakers, ensmbl_eouts, ensmbl_elens,
                        ensmbl_decs)

        return nbest_hyps_id, aws
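
Finally, this variant reads recog_streaming_encoding to switch between the offline encoder and block-wise streaming encoding, and resets the decoder session whenever a new utterance id arrives. A hypothetical sketch of the streaming path, with `model`, `idx2token` and `xs` assumed as before:

# Hypothetical sketch: streaming encoding followed by greedy decoding.
params = {
    'recog_streaming_encoding': True,   # route through encode_streaming()
    'recog_beam_width': 1,              # greedy branch
    'recog_fwd_bwd_attention': False,
    'recog_ctc_weight': 0.0,
    'recog_bwd_attention': False,
    'recog_max_len_ratio': 1.0,
}
# The same utt_id keeps the session; a different one triggers reset_session().
nbest_hyps_id, aws = model.decode(xs, params, idx2token, utt_ids=['utt-0002'])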