def __init__(
    self,
    asr_train_config: Union[Path, str],
    asr_model_file: Union[Path, str] = None,
    lm_train_config: Union[Path, str] = None,
    lm_file: Union[Path, str] = None,
    token_type: str = None,
    bpemodel: str = None,
    device: str = "cpu",
    maxlenratio: float = 0.0,
    minlenratio: float = 0.0,
    batch_size: int = 1,
    dtype: str = "float32",
    beam_size: int = 8,
    ctc_weight: float = 0.5,
    lm_weight: float = 1.0,
    penalty: float = 0.0,
    nbest: int = 1,
    streaming: bool = False,
    output_beam_size: int = 8,
):
    """Load an ASR model and prepare a CTC-topology decoding graph.

    The model is built from ``asr_train_config`` / ``asr_model_file`` on
    ``device``, cast to ``dtype`` and put in eval mode.  A CTC topology over
    all token ids is arc-sorted and moved to ``device``.  ``token_type`` and
    ``bpemodel`` fall back to the values stored in the training arguments
    when they are ``None``.

    The LM, length-ratio, batch, beam-search weight, ``nbest`` and
    ``streaming`` parameters are accepted but not used by this constructor.
    """
    assert check_argument_types()

    # 1. Build ASR model
    model, train_args = ASRTask.build_model_from_file(
        asr_train_config, asr_model_file, device
    )
    model.to(dtype=getattr(torch, dtype)).eval()
    vocab = model.token_list

    # CTC topology over every token id, arc-sorted, living on `device`.
    topo = build_ctc_topo(list(range(len(vocab))))
    self.decode_graph = k2.arc_sort(topo).to(device)

    # Explicit arguments win; otherwise reuse the training-time settings.
    token_type = train_args.token_type if token_type is None else token_type
    bpemodel = train_args.bpemodel if bpemodel is None else bpemodel

    if token_type is None:
        text_tokenizer = None
    elif token_type != "bpe":
        text_tokenizer = build_tokenizer(token_type=token_type)
    elif bpemodel is None:
        # BPE requested but no model available: no tokenizer is built.
        text_tokenizer = None
    else:
        text_tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
    id_converter = TokenIDConverter(token_list=vocab)
    logging.info(f"Text tokenizer: {text_tokenizer}")
    logging.info(f"Running on : {device}")

    self.asr_model = model
    self.asr_train_args = train_args
    self.converter = id_converter
    self.tokenizer = text_tokenizer
    self.device = device
    self.dtype = dtype
    self.output_beam_size = output_beam_size
def __init__(
    self,
    asr_train_config: Union[Path, str],
    asr_model_file: Union[Path, str] = None,
    token_type: str = None,
    bpemodel: str = None,
    device: str = "cpu",
    batch_size: int = 1,
    dtype: str = "float32",
    maskctc_n_iterations: int = 10,
    maskctc_threshold_probability: float = 0.99,
):
    """Load an ASR model and wrap it in a ``MaskCTCInference`` module.

    The model is built from ``asr_train_config`` / ``asr_model_file`` on
    ``device``, cast to ``dtype`` and put in eval mode.  The inference
    wrapper receives ``maskctc_n_iterations`` and
    ``maskctc_threshold_probability`` and is moved to the same device and
    dtype.  ``token_type`` and ``bpemodel`` fall back to the training
    arguments when ``None``.  ``batch_size`` is accepted but not used here.
    """
    assert check_argument_types()

    # 1. Build ASR model
    model, train_args = ASRTask.build_model_from_file(
        asr_train_config, asr_model_file, device
    )
    torch_dtype = getattr(torch, dtype)
    model.to(dtype=torch_dtype).eval()
    vocab = model.token_list

    mask_ctc = MaskCTCInference(
        asr_model=model,
        n_iterations=maskctc_n_iterations,
        threshold_probability=maskctc_threshold_probability,
    )
    mask_ctc.to(device=device, dtype=torch_dtype).eval()

    # 2. [Optional] Build Text converter: e.g. bpe-sym -> Text
    token_type = train_args.token_type if token_type is None else token_type
    bpemodel = train_args.bpemodel if bpemodel is None else bpemodel

    if token_type is None:
        text_tokenizer = None
    elif token_type != "bpe":
        text_tokenizer = build_tokenizer(token_type=token_type)
    elif bpemodel is None:
        # BPE requested but no model available: no tokenizer is built.
        text_tokenizer = None
    else:
        text_tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
    id_converter = TokenIDConverter(token_list=vocab)
    logging.info(f"Text tokenizer: {text_tokenizer}")

    self.asr_model = model
    self.asr_train_args = train_args
    self.s2t = mask_ctc
    self.converter = id_converter
    self.tokenizer = text_tokenizer
    self.device = device
    self.dtype = dtype
def extract_zip_model_file(self, zip_model_file: str) -> Dict[str, Any]:
    """Unpack a model archive and load the ASR model it describes.

    The archive is extracted into the directory that contains it.
    ``meta.yaml`` from the extracted tree is read to find the model state
    file and the training config; the model is then built with
    ``self.device`` and ``self.build_beam_search()`` /
    ``self.build_tokenizer()`` are called.

    Sets ``self.model_config``, ``self.global_cmvn`` and ``self.model``.

    Args:
        zip_model_file (str): Zip archive of the model generated by the
            training scripts.

    Raises:
        ValueError: If ``zip_model_file`` is not a zip archive.
        FileNotFoundError: If the extracted archive does not contain the
            required ``exp`` and ``meta.yaml`` entries.

    Returns:
        Dict[str, Any]: Content of the ``.yaml`` config used during training,
        as needed to load the model correctly (same object as
        ``self.model_config``).
    """
    import os  # local import keeps this block independent of file-level imports

    print("Unzipping model")
    if not zipfile.is_zipfile(zip_model_file):
        raise ValueError(f"File {zip_model_file} is not a zipfile")

    extract_dir = os.path.dirname(zip_model_file)
    # Context manager guarantees the archive handle is closed.
    with zipfile.ZipFile(zip_model_file) as archive:
        archive.extractall(extract_dir)

    def _in_extract_dir(path: str) -> str:
        # Archive-relative paths live under the extraction directory,
        # which is not necessarily the current working directory.
        return path if os.path.isabs(path) else os.path.join(extract_dir, path)

    # Actually verify the required entries exist after extraction.
    check = ['exp', 'meta.yaml']
    missing = [name for name in check if not os.path.exists(_in_extract_dir(name))]
    if missing:
        raise FileNotFoundError(
            f"Archive {zip_model_file} is missing required entries: {missing}"
        )

    print("Load yaml file")
    # NOTE(review): yaml.FullLoader is not meant for untrusted input; only
    # load archives from trusted sources (or confirm safe_load is sufficient).
    with open(_in_extract_dir('meta.yaml'), encoding="utf-8") as f:
        meta = yaml.load(f, Loader=yaml.FullLoader)
    model_stats_file = _in_extract_dir(meta['files']['asr_model_file'])
    asr_model_config_file = _in_extract_dir(meta['yaml_files']['asr_train_config'])

    self.model_config = {}
    with open(asr_model_config_file, encoding="utf-8") as f:
        self.model_config = yaml.load(f, Loader=yaml.FullLoader)
    try:
        # NOTE(review): this path is taken verbatim from the config; if it is
        # relative it is still resolved against the CWD by whoever reads it.
        self.global_cmvn = self.model_config['normalize_conf']['stats_file']
    except (KeyError, TypeError):
        # No normalization stats configured (key absent or section is null).
        self.global_cmvn = None

    print(f'Loading model config from {asr_model_config_file}')
    print(f'Loading model state from {model_stats_file}')

    # Build Model
    print('Building model')
    self.model, _ = ASRTask.build_model_from_file(
        asr_model_config_file, model_stats_file, self.device
    )
    self.model.to(dtype=getattr(torch, 'float32')).eval()

    self.build_beam_search()
    self.build_tokenizer()
    return self.model_config
def inference(
    output_dir: str,
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    dtype: str,
    beam_size: int,
    ngpu: int,
    seed: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    nbest: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    key_file: Optional[str],
    asr_train_config: str,
    asr_model_file: str,
    lm_train_config: Optional[str],
    lm_file: Optional[str],
    word_lm_train_config: Optional[str],
    word_lm_file: Optional[str],
    blank_symbol: str,
    token_type: Optional[str],
    bpemodel: Optional[str],
    allow_variable_data_keys: bool,
):
    """Decode a dataset with beam search and write n-best results to disk.

    Builds the ASR model (and optionally an LM), a ``BeamSearch`` combining
    decoder, CTC, LM and length-bonus scores, and a streaming data iterator.
    For every utterance the encoder output is decoded and, for each of the
    up-to-``nbest`` hypotheses, ``token``, ``token_int``, ``score`` and (when
    a tokenizer is available) ``text`` are written under
    ``output_dir/{n}best_recog``.

    Raises:
        NotImplementedError: If ``batch_size > 1``, a word LM config is
            given, or ``ngpu > 1``.

    Note:
        ``word_lm_file`` and ``blank_symbol`` are accepted but not used; the
        blank symbol id is assumed to be 0 when filtering hypotheses.
    """
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    device = "cuda" if ngpu >= 1 else "cpu"
    torch_dtype = getattr(torch, dtype)

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build ASR model
    scorers = {}
    asr_model, asr_train_args = ASRTask.build_model_from_file(
        asr_train_config, asr_model_file, device
    )
    # Cast the model too: the data iterator and beam search below both use
    # `dtype`, so a float32-only model would mismatch for any other dtype.
    asr_model.to(dtype=torch_dtype).eval()

    decoder = asr_model.decoder
    ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
    token_list = asr_model.token_list
    scorers.update(
        decoder=decoder,
        ctc=ctc,
        length_bonus=LengthBonus(len(token_list)),
    )

    # 3. Build Language model
    if lm_train_config is not None:
        lm, _ = LMTask.build_model_from_file(lm_train_config, lm_file, device)
        scorers["lm"] = lm.lm

    # 4. Build BeamSearch object
    weights = dict(
        decoder=1.0 - ctc_weight,
        ctc=ctc_weight,
        lm=lm_weight,
        length_bonus=penalty,
    )
    beam_search = BeamSearch(
        beam_size=beam_size,
        weights=weights,
        scorers=scorers,
        sos=asr_model.sos,
        eos=asr_model.eos,
        vocab_size=len(token_list),
        token_list=token_list,
    )
    beam_search.to(device=device, dtype=torch_dtype).eval()
    for scorer in scorers.values():
        if isinstance(scorer, torch.nn.Module):
            scorer.to(device=device, dtype=torch_dtype).eval()
    logging.info(f"Beam_search: {beam_search}")
    logging.info(f"Decoding device={device}, dtype={dtype}")

    # 5. Build data-iterator
    loader = ASRTask.build_streaming_iterator(
        data_path_and_name_and_type,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        preprocess_fn=ASRTask.build_preprocess_fn(asr_train_args, False),
        collate_fn=ASRTask.build_collate_fn(asr_train_args),
        allow_variable_data_keys=allow_variable_data_keys,
        inference=True,
    )

    # 6. [Optional] Build Text converter: e.g. bpe-sym -> Text
    if token_type is None:
        token_type = asr_train_args.token_type
    if bpemodel is None:
        bpemodel = asr_train_args.bpemodel

    if token_type is None:
        tokenizer = None
    elif token_type == "bpe":
        if bpemodel is not None:
            tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
        else:
            tokenizer = None
    else:
        tokenizer = build_tokenizer(token_type=token_type)
    converter = TokenIDConverter(token_list=token_list)
    logging.info(f"Text tokenizer: {tokenizer}")

    # 7. Start for-loop
    # FIXME(kamo): The output format should be discussed about
    with DatadirWriter(output_dir) as writer:
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"

            with torch.no_grad():
                # a. To device
                batch = to_device(batch, device)

                # b. Forward Encoder
                enc, _ = asr_model.encode(**batch)
                assert len(enc) == batch_size, len(enc)

                # c. Passed the encoder result and the beam search
                nbest_hyps = beam_search(
                    x=enc[0], maxlenratio=maxlenratio, minlenratio=minlenratio
                )
                nbest_hyps = nbest_hyps[:nbest]

            # Only supporting batch_size==1
            key = keys[0]
            # Iterate over what beam search actually returned: it may yield
            # fewer than `nbest` hypotheses, and indexing would then fail.
            for n, hyp in enumerate(nbest_hyps, start=1):
                assert isinstance(hyp, Hypothesis), type(hyp)

                # remove sos/eos and get results
                token_int = hyp.yseq[1:-1].tolist()

                # remove blank symbol id, which is assumed to be 0
                token_int = [tok for tok in token_int if tok != 0]

                # Change integer-ids to tokens
                token = converter.ids2tokens(token_int)

                # Create a directory: outdir/{n}best_recog
                ibest_writer = writer[f"{n}best_recog"]

                # Write the result to each files
                ibest_writer["token"][key] = " ".join(token)
                ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                ibest_writer["score"][key] = str(hyp.score)

                if tokenizer is not None:
                    text = tokenizer.tokens2text(token)
                    ibest_writer["text"][key] = text
def __init__(
    self,
    asr_train_config: Union[Path, str],
    asr_model_file: Union[Path, str] = None,
    fs: int = 16000,
    ngpu: int = 0,
    batch_size: int = 1,
    dtype: str = "float32",
    kaldi_style_text: bool = True,
    text_converter: str = "tokenize",
    time_stamps: str = "auto",
    **ctc_segmentation_args,
):
    """Set up the CTC segmentation module.

    Args:
        asr_train_config: Path to the ASR model config (yaml).
        asr_model_file: Path to the ASR model weights (pth).
        fs: Sampling rate of the audio.
        ngpu: GPU count. 0 runs on CPU, 1 runs on GPU; aligning on
            several GPUs is not implemented. Default: 0.
        batch_size: Only a batch size of 1 is implemented.
        dtype: Inference data type; choose it to match the ASR model.
        kaldi_style_text: Whether each text line begins with the
            utterance name (kaldi style). When False, utterance names
            are generated automatically. Default: True.
        text_converter: Text handling mode for CTC segmentation.
            "tokenize" tokenizes with the ESPnet 2 preprocessing;
            "classic" preprocesses as in ESPnet 1, accounting for token
            length, which may work better for models with longer
            tokens. Default: "tokenize".
        time_stamps: Time stamp calculation method. Both "fixed" and
            "auto" rely on the sample rate; "auto" determines the
            samples-to-frame ratio for each inference, whereas "fixed"
            uses a ratio determined once by the module that can be
            overridden with ``samples_to_frames_ratio``. "auto" is
            recommended for longer audio files.
        **ctc_segmentation_args: Further CTC segmentation parameters.

    Raises:
        NotImplementedError: If ``batch_size > 1`` or ``ngpu > 1``.
    """
    assert check_argument_types()

    # Basic settings
    if batch_size > 1:
        raise NotImplementedError("Batch decoding is not implemented")
    if ngpu > 1:
        logging.error("Multi-GPU not yet implemented.")
        raise NotImplementedError("Only single GPU decoding is supported")
    device = "cuda" if ngpu == 1 else "cpu"

    # Prepare ASR model
    model, train_args = ASRTask.build_model_from_file(
        asr_train_config, asr_model_file, device
    )
    model.to(dtype=getattr(torch, dtype)).eval()
    self.preprocess_fn = ASRTask.build_preprocess_fn(train_args, False)

    # Warn for nets with high memory consumption on long audio files
    encoder_module = (
        model.encoder.__class__.__module__ if hasattr(model, "encoder") else "Unknown"
    )
    logging.info(f"Encoder module: {encoder_module}")
    logging.info(f"CTC module:     {model.ctc.__class__.__module__}")
    is_rnn = "rnn" in encoder_module.lower()
    if not is_rnn:
        logging.warning("No RNN model detected; memory consumption may be high.")

    self.asr_model = model
    self.asr_train_args = train_args
    self.device = device
    self.dtype = dtype
    self.ctc = model.ctc
    self.kaldi_style_text = kaldi_style_text
    self.token_list = model.token_list

    # Apply configuration
    self.set_config(
        fs=fs,
        time_stamps=time_stamps,
        kaldi_style_text=kaldi_style_text,
        text_converter=text_converter,
        **ctc_segmentation_args,
    )
    # last token "<sos/eos>", not needed
    self.config.char_list = model.token_list[:-1]
def __init__(
    self,
    asr_train_config: Union[Path, str],
    asr_model_file: Union[Path, str] = None,
    lm_train_config: Union[Path, str] = None,
    lm_file: Union[Path, str] = None,
    token_type: str = None,
    bpemodel: str = None,
    device: str = "cpu",
    maxlenratio: float = 0.0,
    minlenratio: float = 0.0,
    batch_size: int = 1,
    dtype: str = "float32",
    beam_size: int = 20,
    ctc_weight: float = 0.5,
    lm_weight: float = 1.0,
    penalty: float = 0.0,
    nbest: int = 1,
    disable_repetition_detection=False,
    decoder_text_length_limit=0,
    encoded_feat_length_limit=0,
):
    """Prepare an ASR model and a ``BatchBeamSearchOnline`` decoder.

    The encoder must be a ``ContextualBlockTransformerEncoder`` or a
    ``ContextualBlockConformerEncoder``, the training arguments must carry
    ``look_ahead``, ``hop_size`` and ``block_size`` in ``encoder_conf``,
    ``batch_size`` must be 1 and every full scorer must implement
    ``BatchScorerInterface``; all of these are enforced with assertions.

    ``n_fft``, ``hop_length`` and ``win_length`` are read from the frontend
    config, defaulting to 512, 128 and ``n_fft`` respectively, and
    ``self.reset()`` is called at the end.
    """
    assert check_argument_types()

    # 1. Build ASR model
    scorers = {}
    model, train_args = ASRTask.build_model_from_file(
        asr_train_config, asr_model_file, device
    )
    torch_dtype = getattr(torch, dtype)
    model.to(dtype=torch_dtype).eval()

    assert isinstance(
        model.encoder,
        (ContextualBlockTransformerEncoder, ContextualBlockConformerEncoder),
    )

    vocab = model.token_list
    scorers.update(
        decoder=model.decoder,
        ctc=CTCPrefixScorer(ctc=model.ctc, eos=model.eos),
        length_bonus=LengthBonus(len(vocab)),
    )

    # 2. Build Language model
    if lm_train_config is not None:
        lm, lm_train_args = LMTask.build_model_from_file(
            lm_train_config, lm_file, device
        )
        scorers["lm"] = lm.lm

    # 3. Build BeamSearch object
    weights = {
        "decoder": 1.0 - ctc_weight,
        "ctc": ctc_weight,
        "lm": lm_weight,
        "length_bonus": penalty,
    }

    # Block-processing settings must be present in the encoder config.
    assert "encoder_conf" in train_args
    assert "look_ahead" in train_args.encoder_conf
    assert "hop_size" in train_args.encoder_conf
    assert "block_size" in train_args.encoder_conf

    assert batch_size == 1

    searcher = BatchBeamSearchOnline(
        beam_size=beam_size,
        weights=weights,
        scorers=scorers,
        sos=model.sos,
        eos=model.eos,
        vocab_size=len(vocab),
        token_list=vocab,
        pre_beam_score_key=None if ctc_weight == 1.0 else "full",
        disable_repetition_detection=disable_repetition_detection,
        decoder_text_length_limit=decoder_text_length_limit,
        encoded_feat_length_limit=encoded_feat_length_limit,
    )

    # TODO(karita): make all scorers batchfied
    non_batch = [
        name
        for name, scorer in searcher.full_scorers.items()
        if not isinstance(scorer, BatchScorerInterface)
    ]
    assert len(non_batch) == 0

    logging.info("BatchBeamSearchOnline implementation is selected.")

    searcher.to(device=device, dtype=torch_dtype).eval()
    for scorer in scorers.values():
        if isinstance(scorer, torch.nn.Module):
            scorer.to(device=device, dtype=torch_dtype).eval()
    logging.info(f"Beam_search: {searcher}")
    logging.info(f"Decoding device={device}, dtype={dtype}")

    # 4. [Optional] Build Text converter: e.g. bpe-sym -> Text
    token_type = train_args.token_type if token_type is None else token_type
    bpemodel = train_args.bpemodel if bpemodel is None else bpemodel

    if token_type is None:
        text_tokenizer = None
    elif token_type != "bpe":
        text_tokenizer = build_tokenizer(token_type=token_type)
    elif bpemodel is None:
        # BPE requested but no model available: no tokenizer is built.
        text_tokenizer = None
    else:
        text_tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
    id_converter = TokenIDConverter(token_list=vocab)
    logging.info(f"Text tokenizer: {text_tokenizer}")

    self.asr_model = model
    self.asr_train_args = train_args
    self.converter = id_converter
    self.tokenizer = text_tokenizer
    self.beam_search = searcher
    self.maxlenratio = maxlenratio
    self.minlenratio = minlenratio
    self.device = device
    self.dtype = dtype
    self.nbest = nbest

    # STFT settings taken from the frontend config, with fixed fallbacks.
    frontend_conf = train_args.frontend_conf
    self.n_fft = frontend_conf["n_fft"] if "n_fft" in frontend_conf else 512
    self.hop_length = (
        frontend_conf["hop_length"] if "hop_length" in frontend_conf else 128
    )
    has_win_length = (
        "win_length" in frontend_conf and frontend_conf["win_length"] is not None
    )
    self.win_length = frontend_conf["win_length"] if has_win_length else self.n_fft

    self.reset()
def __init__(
    self,
    asr_train_config: Union[Path, str],
    asr_model_file: Union[Path, str] = None,
    lm_train_config: Union[Path, str] = None,
    lm_file: Union[Path, str] = None,
    token_type: str = None,
    bpemodel: str = None,
    device: str = "cpu",
    maxlenratio: float = 0.0,
    minlenratio: float = 0.0,
    batch_size: int = 1,
    dtype: str = "float32",
    beam_size: int = 20,
    ctc_weight: float = 0.5,
    lm_weight: float = 1.0,
    penalty: float = 0.0,
    nbest: int = 1,
    streaming: bool = False,
):
    """Prepare an ASR model, optional LM and a beam-search decoder.

    A ``BeamSearch`` over decoder, CTC, LM and length-bonus scores is
    created.  When ``batch_size`` is 1 and every full scorer implements
    ``BatchScorerInterface``, the instance's class is switched to
    ``BatchBeamSearchOnlineSim`` (if ``streaming``) or ``BatchBeamSearch``;
    otherwise a warning is logged and the plain implementation is kept.
    ``token_type`` and ``bpemodel`` fall back to the training arguments
    when ``None``.
    """
    assert check_argument_types()

    # 1. Build ASR model
    scorers = {}
    model, train_args = ASRTask.build_model_from_file(
        asr_train_config, asr_model_file, device
    )
    torch_dtype = getattr(torch, dtype)
    model.to(dtype=torch_dtype).eval()

    vocab = model.token_list
    scorers.update(
        decoder=model.decoder,
        ctc=CTCPrefixScorer(ctc=model.ctc, eos=model.eos),
        length_bonus=LengthBonus(len(vocab)),
    )

    # 2. Build Language model
    if lm_train_config is not None:
        lm, lm_train_args = LMTask.build_model_from_file(
            lm_train_config, lm_file, device
        )
        scorers["lm"] = lm.lm

    # 3. Build BeamSearch object
    weights = {
        "decoder": 1.0 - ctc_weight,
        "ctc": ctc_weight,
        "lm": lm_weight,
        "length_bonus": penalty,
    }
    searcher = BeamSearch(
        beam_size=beam_size,
        weights=weights,
        scorers=scorers,
        sos=model.sos,
        eos=model.eos,
        vocab_size=len(vocab),
        token_list=vocab,
        pre_beam_score_key=None if ctc_weight == 1.0 else "full",
    )

    # TODO(karita): make all scorers batchfied
    if batch_size == 1:
        non_batch = [
            name
            for name, scorer in searcher.full_scorers.items()
            if not isinstance(scorer, BatchScorerInterface)
        ]
        if len(non_batch) != 0:
            logging.warning(
                f"As non-batch scorers {non_batch} are found, "
                f"fall back to non-batch implementation."
            )
        elif streaming:
            # Swap the implementation in place by re-classing the instance.
            searcher.__class__ = BatchBeamSearchOnlineSim
            searcher.set_streaming_config(asr_train_config)
            logging.info("BatchBeamSearchOnlineSim implementation is selected.")
        else:
            searcher.__class__ = BatchBeamSearch
            logging.info("BatchBeamSearch implementation is selected.")

    searcher.to(device=device, dtype=torch_dtype).eval()
    for scorer in scorers.values():
        if isinstance(scorer, torch.nn.Module):
            scorer.to(device=device, dtype=torch_dtype).eval()
    logging.info(f"Beam_search: {searcher}")
    logging.info(f"Decoding device={device}, dtype={dtype}")

    # 4. [Optional] Build Text converter: e.g. bpe-sym -> Text
    token_type = train_args.token_type if token_type is None else token_type
    bpemodel = train_args.bpemodel if bpemodel is None else bpemodel

    if token_type is None:
        text_tokenizer = None
    elif token_type != "bpe":
        text_tokenizer = build_tokenizer(token_type=token_type)
    elif bpemodel is None:
        # BPE requested but no model available: no tokenizer is built.
        text_tokenizer = None
    else:
        text_tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
    id_converter = TokenIDConverter(token_list=vocab)
    logging.info(f"Text tokenizer: {text_tokenizer}")

    self.asr_model = model
    self.asr_train_args = train_args
    self.converter = id_converter
    self.tokenizer = text_tokenizer
    self.beam_search = searcher
    self.maxlenratio = maxlenratio
    self.minlenratio = minlenratio
    self.device = device
    self.dtype = dtype
    self.nbest = nbest
def __init__(
    self,
    asr_train_config: Union[Path, str],
    asr_model_file: Union[Path, str] = None,
    lm_train_config: Union[Path, str] = None,
    lm_file: Union[Path, str] = None,
    token_type: str = None,
    bpemodel: str = None,
    device: str = "cpu",
    maxlenratio: float = 0.0,
    minlenratio: float = 0.0,
    batch_size: int = 1,
    dtype: str = "float32",
    beam_size: int = 8,
    ctc_weight: float = 0.5,
    lm_weight: float = 1.0,
    penalty: float = 0.0,
    nbest: int = 1,
    streaming: bool = False,
    search_beam_size: int = 20,
    output_beam_size: int = 20,
    min_active_states: int = 30,
    max_active_states: int = 10000,
    blank_bias: float = 0.0,
    lattice_weight: float = 1.0,
    is_ctc_decoding: bool = True,
    lang_dir: Optional[str] = None,
    use_fgram_rescoring: bool = False,
    use_nbest_rescoring: bool = False,
    am_weight: float = 1.0,
    decoder_weight: float = 0.5,
    nnlm_weight: float = 1.0,
    num_paths: int = 1000,
    nbest_batch_size: int = 500,
    nll_batch_size: int = 100,
):
    """Load an ASR model (and optional LM) and build a CTC decoding graph.

    The ASR model is built on ``device``, cast to ``dtype`` and put in eval
    mode.  When ``lm_train_config`` is given, the language model is loaded
    and stored in ``self.lm``; otherwise ``self.lm`` is ``None``.  Only
    ``is_ctc_decoding=True`` is supported (enforced by an assertion); the
    decoding graph is an arc-sorted CTC topology over all token ids, moved
    to ``device``.  ``token_type`` and ``bpemodel`` fall back to the
    training arguments when ``None``.

    The search, rescoring and batching parameters are stored on the instance
    as attributes of the same name.  ``maxlenratio``, ``minlenratio``,
    ``batch_size``, ``beam_size``, ``ctc_weight``, ``lm_weight``,
    ``penalty``, ``nbest``, ``streaming`` and ``lang_dir`` are accepted but
    not used by this constructor.
    """
    assert check_argument_types()

    # 1. Build ASR model
    asr_model, asr_train_args = ASRTask.build_model_from_file(
        asr_train_config, asr_model_file, device
    )
    asr_model.to(dtype=getattr(torch, dtype)).eval()

    token_list = asr_model.token_list

    # 2. Build Language model
    # Always define the attribute so later reads never hit a missing `lm`.
    self.lm = None
    if lm_train_config is not None:
        lm, _ = LMTask.build_model_from_file(lm_train_config, lm_file, device)
        self.lm = lm

    self.is_ctc_decoding = is_ctc_decoding
    self.use_fgram_rescoring = use_fgram_rescoring
    self.use_nbest_rescoring = use_nbest_rescoring

    assert self.is_ctc_decoding, "Currently, only ctc_decoding graph is supported."
    if self.is_ctc_decoding:
        self.decode_graph = k2.arc_sort(
            build_ctc_topo(list(range(len(token_list))))
        )

    self.decode_graph = self.decode_graph.to(device)

    # Explicit arguments win; otherwise reuse the training-time settings.
    if token_type is None:
        token_type = asr_train_args.token_type
    if bpemodel is None:
        bpemodel = asr_train_args.bpemodel

    if token_type is None:
        tokenizer = None
    elif token_type == "bpe":
        if bpemodel is not None:
            tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
        else:
            tokenizer = None
    else:
        tokenizer = build_tokenizer(token_type=token_type)
    converter = TokenIDConverter(token_list=token_list)
    logging.info(f"Text tokenizer: {tokenizer}")
    logging.info(f"Running on : {device}")

    self.asr_model = asr_model
    self.asr_train_args = asr_train_args
    self.converter = converter
    self.tokenizer = tokenizer
    self.device = device
    self.dtype = dtype
    self.search_beam_size = search_beam_size
    self.output_beam_size = output_beam_size
    self.min_active_states = min_active_states
    self.max_active_states = max_active_states
    self.blank_bias = blank_bias
    self.lattice_weight = lattice_weight
    self.am_weight = am_weight
    self.decoder_weight = decoder_weight
    self.nnlm_weight = nnlm_weight
    self.num_paths = num_paths
    self.nbest_batch_size = nbest_batch_size
    self.nll_batch_size = nll_batch_size
def __init__(
    self,
    asr_train_config: Union[Path, str],
    asr_model_file: Union[Path, str] = None,
    lm_train_config: Union[Path, str] = None,
    lm_file: Union[Path, str] = None,
    token_type: str = None,
    bpemodel: str = None,
    device: str = "cpu",
    maxlenratio: float = 0.0,
    minlenratio: float = 0.0,
    dtype: str = "float32",
    beam_size: int = 20,
    ctc_weight: float = 0.5,
    lm_weight: float = 1.0,
    penalty: float = 0.0,
    nbest: int = 1,
):
    """Prepare an ASR model, optional LM and a ``BeamSearch`` decoder.

    The ASR model is built on ``device``, cast to ``dtype`` and put in eval
    mode.  Decoder, CTC prefix and length-bonus scorers are always used; an
    LM scorer is added when ``lm_train_config`` is given.  ``token_type``
    and ``bpemodel`` fall back to the training arguments when ``None``.

    ``self.lm_train_args`` holds the LM training arguments, or ``None`` when
    no language model was requested.
    """
    assert check_argument_types()

    torch_dtype = getattr(torch, dtype)

    # 1. Build ASR model
    scorers = {}
    asr_model, asr_train_args = ASRTask.build_model_from_file(
        asr_train_config, asr_model_file, device
    )
    # Cast the model as well: beam search and scorers below are moved to
    # `dtype`, so leaving the model in its load-time dtype would mismatch.
    asr_model.to(dtype=torch_dtype).eval()

    decoder = asr_model.decoder
    ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
    token_list = asr_model.token_list
    scorers.update(
        decoder=decoder,
        ctc=ctc,
        length_bonus=LengthBonus(len(token_list)),
    )

    # 2. Build Language model
    # Default to None so the attribute assignment at the end cannot hit an
    # unbound local when no LM config is supplied.
    lm_train_args = None
    if lm_train_config is not None:
        lm, lm_train_args = LMTask.build_model_from_file(
            lm_train_config, lm_file, device
        )
        scorers["lm"] = lm.lm

    # 3. Build BeamSearch object
    weights = dict(
        decoder=1.0 - ctc_weight,
        ctc=ctc_weight,
        lm=lm_weight,
        length_bonus=penalty,
    )
    beam_search = BeamSearch(
        beam_size=beam_size,
        weights=weights,
        scorers=scorers,
        sos=asr_model.sos,
        eos=asr_model.eos,
        vocab_size=len(token_list),
        token_list=token_list,
    )
    beam_search.to(device=device, dtype=torch_dtype).eval()
    for scorer in scorers.values():
        if isinstance(scorer, torch.nn.Module):
            scorer.to(device=device, dtype=torch_dtype).eval()
    logging.info(f"Beam_search: {beam_search}")
    logging.info(f"Decoding device={device}, dtype={dtype}")

    # 4. [Optional] Build Text converter: e.g. bpe-sym -> Text
    if token_type is None:
        token_type = asr_train_args.token_type
    if bpemodel is None:
        bpemodel = asr_train_args.bpemodel

    if token_type is None:
        tokenizer = None
    elif token_type == "bpe":
        if bpemodel is not None:
            tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
        else:
            tokenizer = None
    else:
        tokenizer = build_tokenizer(token_type=token_type)
    converter = TokenIDConverter(token_list=token_list)
    logging.info(f"Text tokenizer: {tokenizer}")

    self.asr_model = asr_model
    self.asr_train_args = asr_train_args
    self.lm_train_args = lm_train_args
    self.converter = converter
    self.tokenizer = tokenizer
    self.beam_search = beam_search
    self.maxlenratio = maxlenratio
    self.minlenratio = minlenratio
    self.device = device
    self.dtype = dtype
    self.nbest = nbest