def decode(self, xs, params, idx2token, exclude_eos=False, refs_id=None,
           refs=None, utt_ids=None, speakers=None, task='ys',
           ensemble_models=[]):
    """Decoding in the inference stage.

    Args:
        xs (list): A list of length `[B]`, which contains arrays of size `[T, input_dim]`
        params (dict): hyper-parameters for decoding
            (recog_beam_width, recog_max_len_ratio, recog_length_penalty,
             recog_ctc_weight, recog_bwd_attention, recog_fwd_bwd_attention,
             recog_gnmt_decoding, recog_batch_size, ...)
        idx2token (): converter from index to token
        exclude_eos (bool): exclude <eos> from best_hyps_id
        refs_id (list): gold token IDs to compute log likelihood
        refs (list): gold transcriptions
        utt_ids (list): utterance id list
        speakers (list): speaker list
        task (str): ys* or ys_sub1* or ys_sub2*
        ensemble_models (list): list of Speech2Text classes
    Returns:
        best_hyps_id (list): A list of length `[B]`, which contains arrays of size `[L]`
        aws (list): A list of length `[B]`, which contains arrays of size `[L, T, n_heads]`

    """
    # Select the decoding direction / sub-task decoder.
    if task.split('.')[0] == 'ys':
        dir_name = 'bwd' if self.bwd_weight > 0 and params['recog_bwd_attention'] else 'fwd'
    elif task.split('.')[0] == 'ys_sub1':
        dir_name = 'fwd_sub1'
    elif task.split('.')[0] == 'ys_sub2':
        dir_name = 'fwd_sub2'
    else:
        raise ValueError(task)

    self.eval()
    with torch.no_grad():
        # Encode input features.
        # NOTE: the original speech/mtl_per_batch conditional had two identical
        # arms (both called self.encode(xs, task)), so it was dead code.
        eout_dict = self.encode(xs, task)

        # Pure CTC decoding (no attention decoder, or CTC weight pinned to 1)
        if (self.fwd_weight == 0 and self.bwd_weight == 0) or (
                self.ctc_weight > 0 and params['recog_ctc_weight'] == 1):
            lm = getattr(self, 'lm_' + dir_name, None)
            lm_second = getattr(self, 'lm_second', None)
            lm_second_bwd = None  # TODO

            best_hyps_id = getattr(self, 'dec_' + dir_name).decode_ctc(
                eout_dict[task]['xs'], eout_dict[task]['xlens'], params, idx2token,
                lm, lm_second, lm_second_bwd, 1, refs_id, utt_ids, speakers)
            return best_hyps_id, None

        # Attention: greedy decoding (beam width 1, single direction)
        elif params['recog_beam_width'] == 1 and not params['recog_fwd_bwd_attention']:
            best_hyps_id, aws = getattr(self, 'dec_' + dir_name).greedy(
                eout_dict[task]['xs'], eout_dict[task]['xlens'],
                params['recog_max_len_ratio'], idx2token,
                exclude_eos, refs_id, utt_ids, speakers)
        else:
            assert params['recog_batch_size'] == 1

            ctc_log_probs = None
            if params['recog_ctc_weight'] > 0:
                ctc_log_probs = self.dec_fwd.ctc_log_probs(eout_dict[task]['xs'])

            # forward-backward decoding
            if params['recog_fwd_bwd_attention']:
                lm_fwd = getattr(self, 'lm_fwd', None)
                lm_bwd = getattr(self, 'lm_bwd', None)

                # forward decoder
                nbest_hyps_id_fwd, aws_fwd, scores_fwd = self.dec_fwd.beam_search(
                    eout_dict[task]['xs'], eout_dict[task]['xlens'],
                    params, idx2token, lm_fwd, None, lm_bwd, ctc_log_probs,
                    params['recog_beam_width'], False, refs_id, utt_ids, speakers)

                # backward decoder
                nbest_hyps_id_bwd, aws_bwd, scores_bwd, _ = self.dec_bwd.beam_search(
                    eout_dict[task]['xs'], eout_dict[task]['xlens'],
                    params, idx2token, lm_bwd, None, lm_fwd, ctc_log_probs,
                    params['recog_beam_width'], False, refs_id, utt_ids, speakers)

                # forward-backward attention rescoring
                best_hyps_id = fwd_bwd_attention(
                    nbest_hyps_id_fwd, aws_fwd, scores_fwd,
                    nbest_hyps_id_bwd, aws_bwd, scores_bwd,
                    self.eos, params['recog_gnmt_decoding'],
                    params['recog_length_penalty'], idx2token, refs_id)
                aws = None
            else:
                # ensemble: encode with every ensemble member and collect
                # their encoder outputs and decoders
                ensmbl_eouts, ensmbl_elens, ensmbl_decs = [], [], []
                if len(ensemble_models) > 0:
                    for i_e, model in enumerate(ensemble_models):
                        # NOTE: the original speech/mtl_per_batch conditional
                        # had two identical arms here as well.
                        enc_outs_e = model.encode(xs, task)
                        ensmbl_eouts += [enc_outs_e[task]['xs']]
                        ensmbl_elens += [enc_outs_e[task]['xlens']]
                        ensmbl_decs += [getattr(model, 'dec_' + dir_name)]
                        # NOTE: only support for the main task now

                lm = getattr(self, 'lm_' + dir_name, None)
                lm_second = getattr(self, 'lm_second', None)
                # Opposite-direction LM for rescoring.
                # FIX: the original ternary had identical arms ('lm_bwd' twice);
                # the non-forward direction should load the forward LM.
                lm_bwd = getattr(self, 'lm_bwd' if dir_name == 'fwd' else 'lm_fwd', None)

                nbest_hyps_id, aws, scores = getattr(self, 'dec_' + dir_name).beam_search(
                    eout_dict[task]['xs'], eout_dict[task]['xlens'],
                    params, idx2token, lm, lm_second, lm_bwd, ctc_log_probs,
                    1, exclude_eos, refs_id, utt_ids, speakers,
                    ensmbl_eouts, ensmbl_elens, ensmbl_decs)
                best_hyps_id = [hyp[0] for hyp in nbest_hyps_id]

    return best_hyps_id, aws
def decode(self, xs, params, idx2token, exclude_eos=False, refs_id=None,
           refs=None, utt_ids=None, speakers=None, task='ys',
           ensemble_models=[], trigger_points=None, teacher_force=False):
    """Decode in the inference stage.

    Args:
        xs (List): length `[B]`, which contains arrays of size `[T, input_dim]`
        params (dict): hyper-parameters for decoding
        idx2token (): converter from index to token
        exclude_eos (bool): exclude <eos> from best_hyps_id
        refs_id (List): gold token IDs to compute log likelihood
        refs (List): gold transcriptions
        utt_ids (List): utterance id list
        speakers (List): speaker list
        task (str): ys* or ys_sub1* or ys_sub2*
        ensemble_models (List): Speech2Text classes
        trigger_points (np.ndarray): `[B, L]`
        teacher_force (bool): conduct teacher-forcing
    Returns:
        nbest_hyps_id (List[List[np.ndarray]]): length `[B]`, which contains a list of
            length `[n_best]` which contains arrays of size `[L]`
        aws (List[np.ndarray]): length `[B]`, which contains arrays of size `[L, T, n_heads]`

    """
    # Select the decoding direction / sub-task decoder.
    if task.split('.')[0] == 'ys':
        dir_name = 'bwd' if self.bwd_weight > 0 and params['recog_bwd_attention'] else 'fwd'
    elif task.split('.')[0] == 'ys_sub1':
        dir_name = 'fwd_sub1'
    elif task.split('.')[0] == 'ys_sub2':
        dir_name = 'fwd_sub2'
    else:
        raise ValueError(task)

    # Reset streaming/session state when a new utterance id arrives.
    if utt_ids is not None:
        if self.utt_id_prev != utt_ids[0]:
            self.reset_session()
        self.utt_id_prev = utt_ids[0]

    self.eval()
    with torch.no_grad():
        # Encode input features
        if params['recog_streaming_encoding']:
            eouts, elens = self.encode_streaming(xs, params, task)
        else:
            eout_dict = self.encode(xs, task)
            eouts = eout_dict[task]['xs']
            elens = eout_dict[task]['xlens']

        # Pure CTC decoding (no attention decoder, or CTC weight pinned to 1)
        if (self.fwd_weight == 0 and self.bwd_weight == 0) or (
                self.ctc_weight > 0 and params['recog_ctc_weight'] == 1):
            lm = getattr(self, 'lm_' + dir_name, None)
            lm_second = getattr(self, 'lm_second', None)
            lm_second_bwd = None  # TODO

            if params.get('recog_beam_width') == 1:
                nbest_hyps_id = getattr(self, 'dec_' + dir_name).ctc.greedy(eouts, elens)
            else:
                nbest_hyps_id = getattr(self, 'dec_' + dir_name).ctc.beam_search(
                    eouts, elens, params, idx2token,
                    lm, lm_second, lm_second_bwd, 1,
                    refs_id, utt_ids, speakers)
            return nbest_hyps_id, None

        # Attention/RNN-T: greedy decoding (beam width 1, single direction)
        elif params['recog_beam_width'] == 1 and not params['recog_fwd_bwd_attention']:
            best_hyps_id, aws = getattr(self, 'dec_' + dir_name).greedy(
                eouts, elens, params['recog_max_len_ratio'], idx2token,
                exclude_eos, refs_id, utt_ids, speakers)
            nbest_hyps_id = [[hyp] for hyp in best_hyps_id]
        else:
            assert params['recog_batch_size'] == 1

            scores_ctc = None
            if params['recog_ctc_weight'] > 0:
                scores_ctc = self.dec_fwd.ctc.scores(eouts)

            # forward-backward decoding
            if params['recog_fwd_bwd_attention']:
                lm = getattr(self, 'lm_fwd', None)
                lm_bwd = getattr(self, 'lm_bwd', None)

                # forward decoder
                nbest_hyps_id_fwd, aws_fwd, scores_fwd = self.dec_fwd.beam_search(
                    eouts, elens, params, idx2token, lm, None, lm_bwd, scores_ctc,
                    params['recog_beam_width'], False, refs_id, utt_ids, speakers)

                # backward decoder
                nbest_hyps_id_bwd, aws_bwd, scores_bwd, _ = self.dec_bwd.beam_search(
                    eouts, elens, params, idx2token, lm_bwd, None, lm, scores_ctc,
                    params['recog_beam_width'], False, refs_id, utt_ids, speakers)

                # forward-backward attention rescoring
                best_hyps_id = fwd_bwd_attention(
                    nbest_hyps_id_fwd, aws_fwd, scores_fwd,
                    nbest_hyps_id_bwd, aws_bwd, scores_bwd,
                    self.eos, params['recog_gnmt_decoding'],
                    params['recog_length_penalty'], idx2token, refs_id)
                nbest_hyps_id = [[hyp] for hyp in best_hyps_id]
                aws = None
            else:
                # ensemble: encode with every ensemble member and collect
                # their encoder outputs and decoders
                ensmbl_eouts, ensmbl_elens, ensmbl_decs = [], [], []
                if len(ensemble_models) > 0:
                    for i_e, model in enumerate(ensemble_models):
                        enc_outs_e = model.encode(xs, task)
                        ensmbl_eouts += [enc_outs_e[task]['xs']]
                        ensmbl_elens += [enc_outs_e[task]['xlens']]
                        ensmbl_decs += [getattr(model, 'dec_' + dir_name)]
                        # NOTE: only support for the main task now

                lm = getattr(self, 'lm_' + dir_name, None)
                lm_second = getattr(self, 'lm_second', None)
                # Opposite-direction LM for rescoring.
                # FIX: the original ternary had identical arms ('lm_bwd' twice);
                # the non-forward direction should load the forward LM.
                lm_bwd = getattr(self, 'lm_bwd' if dir_name == 'fwd' else 'lm_fwd', None)

                nbest_hyps_id, aws, scores = getattr(self, 'dec_' + dir_name).beam_search(
                    eouts, elens, params, idx2token, lm, lm_second, lm_bwd, scores_ctc,
                    params['recog_beam_width'], exclude_eos, refs_id, utt_ids, speakers,
                    ensmbl_eouts, ensmbl_elens, ensmbl_decs)

        return nbest_hyps_id, aws
def decode(self, xs, params, idx2token, nbest=1, exclude_eos=False,
           refs_id=None, refs_text=None, utt_ids=None, speakers=None,
           task='ys', ensemble_models=[]):
    """Decoding in the inference stage.

    Args:
        xs (list): A list of length `[B]`, which contains arrays of size `[T, input_dim]`
        params (dict): hyper-parameters for decoding
            (recog_beam_width, recog_max_len_ratio, recog_length_penalty,
             recog_ctc_weight, recog_lm_weight, recog_bwd_attention,
             recog_fwd_bwd_attention, recog_reverse_lm_rescoring, recog_oracle, ...)
        idx2token (): converter from index to token
        nbest (int): the number of hypotheses to return per utterance
        exclude_eos (bool): exclude <eos> from best_hyps_id
        refs_id (list): gold token IDs to compute log likelihood
        refs_text (list): gold transcriptions
        utt_ids (list): utterance id list
        speakers (list): speaker list
        task (str): ys* or ys_sub1* or ys_sub2*
        ensemble_models (list): list of Speech2Text classes
    Returns:
        best_hyps_id (list): A list of length `[B]`, which contains arrays of size `[L]`
        aws (list): A list of length `[B]`, which contains arrays of size `[L, T, n_heads]`
        cache_info (tuple): decoder cache information (or `(None, None)`)

    """
    self.eval()
    with torch.no_grad():
        # Select the decoding direction / sub-task decoder.
        if task.split('.')[0] == 'ys':
            dir_name = 'bwd' if self.bwd_weight > 0 and params['recog_bwd_attention'] else 'fwd'
        elif task.split('.')[0] == 'ys_sub1':
            dir_name = 'fwd_sub1'
        elif task.split('.')[0] == 'ys_sub2':
            dir_name = 'fwd_sub2'
        else:
            raise ValueError(task)

        # Encode input features (time-flipped for the backward decoder)
        if self.input_type == 'speech' and self.mtl_per_batch and 'bwd' in dir_name:
            enc_outs = self.encode(xs, task, flip=True)
        else:
            enc_outs = self.encode(xs, task, flip=False)

        #########################
        # CTC
        #########################
        if (self.fwd_weight == 0 and self.bwd_weight == 0) or (
                self.ctc_weight > 0 and params['recog_ctc_weight'] == 1):
            lm = None
            if params['recog_lm_weight'] > 0 and hasattr(self, 'lm_fwd') and self.lm_fwd is not None:
                lm = getattr(self, 'lm_' + dir_name)

            best_hyps_id = getattr(self, 'dec_' + dir_name).decode_ctc(
                enc_outs[task]['xs'], enc_outs[task]['xlens'], params, idx2token,
                lm, nbest, refs_id, utt_ids, speakers)
            return best_hyps_id, None, (None, None)

        #########################
        # Attention
        #########################
        else:
            cache_info = (None, None)

            # greedy decoding (beam width 1, single direction)
            if params['recog_beam_width'] == 1 and not params['recog_fwd_bwd_attention']:
                best_hyps_id, aws = getattr(self, 'dec_' + dir_name).greedy(
                    enc_outs[task]['xs'], enc_outs[task]['xlens'],
                    params['recog_max_len_ratio'], idx2token,
                    exclude_eos, refs_id, speakers, params['recog_oracle'])
            else:
                assert params['recog_batch_size'] == 1

                ctc_log_probs = None
                if params['recog_ctc_weight'] > 0:
                    ctc_log_probs = self.dec_fwd.ctc_log_probs(enc_outs[task]['xs'])

                # forward-backward decoding
                if params['recog_fwd_bwd_attention']:
                    # forward decoder LMs
                    lm_fwd, lm_bwd = None, None
                    if params['recog_lm_weight'] > 0 and hasattr(self, 'lm_fwd') and self.lm_fwd is not None:
                        lm_fwd = self.lm_fwd
                    if params['recog_reverse_lm_rescoring'] and hasattr(self, 'lm_bwd') and self.lm_bwd is not None:
                        lm_bwd = self.lm_bwd

                    # ensemble (forward)
                    ensmbl_eouts_fwd = []
                    ensmbl_elens_fwd = []
                    ensmbl_decs_fwd = []
                    if len(ensemble_models) > 0:
                        for i_e, model in enumerate(ensemble_models):
                            enc_outs_e_fwd = model.encode(xs, task, flip=False)
                            ensmbl_eouts_fwd += [enc_outs_e_fwd[task]['xs']]
                            ensmbl_elens_fwd += [enc_outs_e_fwd[task]['xlens']]
                            ensmbl_decs_fwd += [model.dec_fwd]
                            # NOTE: only support for the main task now

                    nbest_hyps_id_fwd, aws_fwd, scores_fwd, cache_info = self.dec_fwd.beam_search(
                        enc_outs[task]['xs'], enc_outs[task]['xlens'],
                        params, idx2token, lm_fwd, lm_bwd, ctc_log_probs,
                        params['recog_beam_width'], False, refs_id, utt_ids, speakers,
                        ensmbl_eouts_fwd, ensmbl_elens_fwd, ensmbl_decs_fwd)

                    # backward decoder LMs
                    lm_bwd, lm_fwd = None, None
                    if params['recog_lm_weight'] > 0 and hasattr(self, 'lm_bwd') and self.lm_bwd is not None:
                        lm_bwd = self.lm_bwd
                    if params['recog_reverse_lm_rescoring'] and hasattr(self, 'lm_fwd') and self.lm_fwd is not None:
                        lm_fwd = self.lm_fwd

                    # ensemble (backward)
                    ensmbl_eouts_bwd = []
                    ensmbl_elens_bwd = []
                    ensmbl_decs_bwd = []
                    if len(ensemble_models) > 0:
                        for i_e, model in enumerate(ensemble_models):
                            # FIX: decide flipping from each ensemble member's own
                            # configuration (the original tested self.input_type /
                            # self.mtl_per_batch, inconsistent with the ensemble
                            # loop in the non-fwd-bwd branch below).
                            if model.input_type == 'speech' and model.mtl_per_batch:
                                enc_outs_e_bwd = model.encode(xs, task, flip=True)
                            else:
                                enc_outs_e_bwd = model.encode(xs, task, flip=False)
                            ensmbl_eouts_bwd += [enc_outs_e_bwd[task]['xs']]
                            ensmbl_elens_bwd += [enc_outs_e_bwd[task]['xlens']]
                            ensmbl_decs_bwd += [model.dec_bwd]
                            # NOTE: only support for the main task now
                            # TODO(hirofumi): merge with the forward for the efficiency

                    flip = False
                    if self.input_type == 'speech' and self.mtl_per_batch:
                        flip = True
                        enc_outs_bwd = self.encode(xs, task, flip=True)
                    else:
                        enc_outs_bwd = enc_outs

                    # NOTE(review): xlens intentionally taken from the unflipped
                    # enc_outs — time flipping preserves lengths; confirm before changing.
                    nbest_hyps_id_bwd, aws_bwd, scores_bwd, _ = self.dec_bwd.beam_search(
                        enc_outs_bwd[task]['xs'], enc_outs[task]['xlens'],
                        params, idx2token, lm_bwd, lm_fwd, ctc_log_probs,
                        params['recog_beam_width'], False, refs_id, utt_ids, speakers,
                        ensmbl_eouts_bwd, ensmbl_elens_bwd, ensmbl_decs_bwd)

                    # forward-backward attention rescoring
                    best_hyps_id = fwd_bwd_attention(
                        nbest_hyps_id_fwd, aws_fwd, scores_fwd,
                        nbest_hyps_id_bwd, aws_bwd, scores_bwd,
                        flip, self.eos, params['recog_gnmt_decoding'],
                        params['recog_length_penalty'], idx2token, refs_id)
                    aws = None
                else:
                    # ensemble: encode with every ensemble member and collect
                    # their encoder outputs and decoders
                    ensmbl_eouts = []
                    ensmbl_elens = []
                    ensmbl_decs = []
                    if len(ensemble_models) > 0:
                        for i_e, model in enumerate(ensemble_models):
                            if model.input_type == 'speech' and model.mtl_per_batch and 'bwd' in dir_name:
                                enc_outs_e = model.encode(xs, task, flip=True)
                            else:
                                enc_outs_e = model.encode(xs, task, flip=False)
                            ensmbl_eouts += [enc_outs_e[task]['xs']]
                            ensmbl_elens += [enc_outs_e[task]['xlens']]
                            ensmbl_decs += [getattr(model, 'dec_' + dir_name)]
                            # NOTE: only support for the main task now

                    lm, lm_rev = None, None
                    if params['recog_lm_weight'] > 0 and hasattr(self, 'lm_' + dir_name) \
                            and getattr(self, 'lm_' + dir_name) is not None:
                        lm = getattr(self, 'lm_' + dir_name)
                    if params['recog_reverse_lm_rescoring']:
                        if dir_name == 'fwd':
                            lm_rev = self.lm_bwd
                        else:
                            raise NotImplementedError

                    nbest_hyps_id, aws, scores, cache_info = getattr(self, 'dec_' + dir_name).beam_search(
                        enc_outs[task]['xs'], enc_outs[task]['xlens'],
                        params, idx2token, lm, lm_rev, ctc_log_probs,
                        nbest, exclude_eos, refs_id, utt_ids, speakers,
                        ensmbl_eouts, ensmbl_elens, ensmbl_decs)

                    if nbest == 1:
                        best_hyps_id = [hyp[0] for hyp in nbest_hyps_id]
                        aws = [aw[0] for aw in aws] if aws is not None else aws
                    else:
                        # NOTE: nbest >= 2 is used for MWER training only
                        return nbest_hyps_id, aws, scores, cache_info

            return best_hyps_id, aws, cache_info
def decode(self, xs, params, idx2token, exclude_eos=False, refs_id=None,
           refs=None, utt_ids=None, speakers=None, task='ys',
           ensemble_models=[], trigger_points=None, teacher_force=False):
    """Decode in the inference stage.

    Args:
        xs (List): length `[B]`, which contains arrays of size `[T, input_dim]`
        params (dict): hyper-parameters for decoding
        idx2token (): converter from index to token
        exclude_eos (bool): exclude <eos> from best_hyps_id
        refs_id (List): gold token IDs to compute log likelihood
        refs (List): gold transcriptions
        utt_ids (List): utterance id list
        speakers (List): speaker list
        task (str): ys* or ys_sub1* or ys_sub2*
        ensemble_models (List): Speech2Text classes
        trigger_points (np.ndarray): `[B, L]`
        teacher_force (bool): conduct teacher-forcing
    Returns:
        nbest_hyps_id (List[List[np.ndarray]]): length `[B]`, which contains a list of
            length `[n_best]` which contains arrays of size `[L]`
        aws (List[np.ndarray]): length `[B]`, which contains arrays of size `[L, T, n_heads]`

    """
    self.eval()

    # Select the decoding direction / sub-task decoder.
    if task.split('.')[0] == 'ys':
        dir_name = 'bwd' if self.bwd_weight > 0 and params['recog_bwd_attention'] else 'fwd'
    elif task.split('.')[0] == 'ys_sub1':
        dir_name = 'fwd_sub1'
    elif task.split('.')[0] == 'ys_sub2':
        dir_name = 'fwd_sub2'
    else:
        raise ValueError(task)

    # Reset streaming/session state when a new utterance id arrives.
    if utt_ids is not None:
        if self.utt_id_prev != utt_ids[0]:
            self.reset_session()
        self.utt_id_prev = utt_ids[0]

    # FIX: run all inference under no_grad (the sibling decode() versions do;
    # it was dropped here, so autograd state was tracked during decoding).
    with torch.no_grad():
        # Encode input features
        if params['recog_streaming_encoding']:
            eouts, elens = self.encode_streaming(xs, params, task)
        else:
            eout_dict = self.encode(xs, task)
            eouts = eout_dict[task]['xs']
            elens = eout_dict[task]['xlens']

        # Pure CTC decoding (no attention decoder, or CTC weight pinned to 1)
        if (self.fwd_weight == 0 and self.bwd_weight == 0) or (
                self.ctc_weight > 0 and params['recog_ctc_weight'] == 1):
            lm = getattr(self, 'lm_' + dir_name, None)
            lm_second = getattr(self, 'lm_second', None)
            lm_second_bwd = None  # TODO

            if params.get('recog_beam_width') == 1:
                nbest_hyps_id = getattr(self, 'dec_' + dir_name).ctc.greedy(eouts, elens)
            else:
                nbest_hyps_id = getattr(self, 'dec_' + dir_name).ctc.beam_search(
                    eouts, elens, params, idx2token,
                    lm, lm_second, lm_second_bwd, 1,
                    refs_id, utt_ids, speakers)
            return nbest_hyps_id, None

        # Attention/RNN-T: greedy decoding (beam width 1, single direction)
        elif params['recog_beam_width'] == 1 and not params['recog_fwd_bwd_attention']:
            best_hyps_id, aws = getattr(self, 'dec_' + dir_name).greedy(
                eouts, elens, params['recog_max_len_ratio'], idx2token,
                exclude_eos, refs_id, utt_ids, speakers)
            nbest_hyps_id = [[hyp] for hyp in best_hyps_id]

        # WFST decoding
        elif self.is_wfst:
            # TODO: config
            nbest_hyps_id = []
            bs = eouts.size(0)
            # decode one utterance at a time
            for b in range(bs):
                encode_out = eouts[b].unsqueeze(0)
                initial_packed_states = (0,)
                inference_one_step = getattr(self, 'dec_' + dir_name).decode_wfst_onestep
                self.decoder.decode(encode_out, initial_packed_states, inference_one_step)
                words_prediction_id = self.decoder.get_best_path()
                words_prediction = ''.join(
                    [self.words[int(idx)] for idx in words_prediction_id])
                predictions = [
                    self.vocab_wfst[prediction] for prediction in words_prediction
                ]
                nbest_hyps_id.append([np.array(predictions)])
            aws = None
        else:
            assert params['recog_batch_size'] == 1

            scores_ctc = None
            if params['recog_ctc_weight'] > 0:
                scores_ctc = self.dec_fwd.ctc.scores(eouts)

            # forward-backward decoding
            if params['recog_fwd_bwd_attention']:
                lm = getattr(self, 'lm_fwd', None)
                lm_bwd = getattr(self, 'lm_bwd', None)

                # forward decoder
                nbest_hyps_id_fwd, aws_fwd, scores_fwd = self.dec_fwd.beam_search(
                    eouts, elens, params, idx2token, lm, None, lm_bwd, scores_ctc,
                    params['recog_beam_width'], False, refs_id, utt_ids, speakers)

                # backward decoder
                nbest_hyps_id_bwd, aws_bwd, scores_bwd, _ = self.dec_bwd.beam_search(
                    eouts, elens, params, idx2token, lm_bwd, None, lm, scores_ctc,
                    params['recog_beam_width'], False, refs_id, utt_ids, speakers)

                # forward-backward attention rescoring
                best_hyps_id = fwd_bwd_attention(
                    nbest_hyps_id_fwd, aws_fwd, scores_fwd,
                    nbest_hyps_id_bwd, aws_bwd, scores_bwd,
                    self.eos, params['recog_gnmt_decoding'],
                    params['recog_length_penalty'], idx2token, refs_id)
                nbest_hyps_id = [[hyp] for hyp in best_hyps_id]
                aws = None
            else:
                # ensemble: encode with every ensemble member and collect
                # their encoder outputs and decoders
                ensmbl_eouts, ensmbl_elens, ensmbl_decs = [], [], []
                if len(ensemble_models) > 0:
                    for i_e, model in enumerate(ensemble_models):
                        enc_outs_e = model.encode(xs, task)
                        ensmbl_eouts += [enc_outs_e[task]['xs']]
                        ensmbl_elens += [enc_outs_e[task]['xlens']]
                        ensmbl_decs += [getattr(model, 'dec_' + dir_name)]
                        # NOTE: only support for the main task now

                lm = getattr(self, 'lm_' + dir_name, None)
                lm_second = getattr(self, 'lm_second', None)
                # Opposite-direction LM for rescoring.
                # FIX: the original ternary had identical arms ('lm_bwd' twice);
                # the non-forward direction should load the forward LM.
                lm_bwd = getattr(self, 'lm_bwd' if dir_name == 'fwd' else 'lm_fwd', None)

                nbest_hyps_id, aws, scores = getattr(self, 'dec_' + dir_name).beam_search(
                    eouts, elens, params, idx2token, lm, lm_second, lm_bwd, scores_ctc,
                    params['recog_beam_width'], exclude_eos, refs_id, utt_ids, speakers,
                    ensmbl_eouts, ensmbl_elens, ensmbl_decs)

        return nbest_hyps_id, aws