def recog_v2(args):
    """Decode with custom models that implements ScorerInterface.

    Notes:
        The previous backend espnet.asr.pytorch_backend.asr.recog
        only supports E2E and RNNLM

    Args:
        args (namespace): The program arguments.
            See py:func:`espnet.bin.asr_recog.get_parser` for details

    Raises:
        NotImplementedError: If multi-utterance batch decoding, streaming mode,
            a word LM, or more than one GPU is requested.
        ValueError: If the requested dynamic quantization is not supported by
            the installed torch version.

    """
    logging.warning("experimental API for custom LMs is selected by --api v2")
    if args.batchsize > 1:
        raise NotImplementedError("multi-utt batch decoding is not implemented")
    if args.streaming_mode is not None:
        raise NotImplementedError("streaming mode is not implemented")
    if args.word_rnnlm:
        raise NotImplementedError("word LM is not implemented")

    set_deterministic_pytorch(args)
    model, train_args = load_trained_model(args.model)
    assert isinstance(model, ASRInterface)

    # torch.nn module classes (looked up by name) targeted by dynamic quantization.
    if args.quantize_config is not None:
        q_config = {getattr(torch.nn, q) for q in args.quantize_config}
    else:
        q_config = {torch.nn.Linear}

    # The LM is only quantized when an LM is actually loaded, so
    # ``quantize_lm_model`` is only consulted together with ``rnnlm``.
    quantize_lm = bool(args.rnnlm) and args.quantize_lm_model

    # Validate the quantization dtype once for every model that will be
    # quantized (ASR model and/or LM), not only for the ASR model.
    if (
        (args.quantize_asr_model or quantize_lm)
        and args.quantize_dtype == "float16"
        and torch.__version__ < LooseVersion("1.5.0")
    ):
        raise ValueError(
            "float16 dtype for dynamic quantization is not supported with torch "
            "version < 1.5.0. Switch to qint8 dtype instead."
        )

    if args.quantize_asr_model:
        logging.info("Use quantized asr model for decoding")
        # See https://github.com/espnet/espnet/pull/3616 for more information.
        if (
            torch.__version__ < LooseVersion("1.4.0")
            and "lstm" in train_args.etype
            and torch.nn.LSTM in q_config
        ):
            raise ValueError(
                "Quantized LSTM in ESPnet is only supported with torch 1.4+."
            )
        model = torch.quantization.quantize_dynamic(
            model, q_config, dtype=getattr(torch, args.quantize_dtype)
        )

    model.eval()
    load_inputs_and_targets = LoadInputsAndTargets(
        mode="asr",
        load_output=False,
        sort_in_input_length=False,
        preprocess_conf=train_args.preprocess_conf
        if args.preprocess_conf is None
        else args.preprocess_conf,
        preprocess_args={"train": False},
    )

    if args.rnnlm:
        lm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        # NOTE: for a compatibility with less than 0.5.0 version models
        lm_model_module = getattr(lm_args, "model_module", "default")
        lm_class = dynamic_import_lm(lm_model_module, lm_args.backend)
        lm = lm_class(len(train_args.char_list), lm_args)
        torch_load(args.rnnlm, lm)
        if quantize_lm:
            logging.info("Use quantized lm model")
            lm = torch.quantization.quantize_dynamic(
                lm, q_config, dtype=getattr(torch, args.quantize_dtype)
            )
        lm.eval()
    else:
        lm = None

    if args.ngram_model:
        from espnet.nets.scorers.ngram import NgramFullScorer
        from espnet.nets.scorers.ngram import NgramPartScorer

        if args.ngram_scorer == "full":
            ngram = NgramFullScorer(args.ngram_model, train_args.char_list)
        else:
            ngram = NgramPartScorer(args.ngram_model, train_args.char_list)
    else:
        ngram = None

    scorers = model.scorers()
    scorers["lm"] = lm
    scorers["ngram"] = ngram
    scorers["length_bonus"] = LengthBonus(len(train_args.char_list))
    weights = dict(
        decoder=1.0 - args.ctc_weight,
        ctc=args.ctc_weight,
        lm=args.lm_weight,
        ngram=args.ngram_weight,
        length_bonus=args.penalty,
    )
    beam_search = BeamSearch(
        beam_size=args.beam_size,
        vocab_size=len(train_args.char_list),
        weights=weights,
        scorers=scorers,
        sos=model.sos,
        eos=model.eos,
        token_list=train_args.char_list,
        pre_beam_score_key=None if args.ctc_weight == 1.0 else "full",
    )
    # TODO(karita): make all scorers batchfied
    if args.batchsize == 1:
        non_batch = [
            k
            for k, v in beam_search.full_scorers.items()
            if not isinstance(v, BatchScorerInterface)
        ]
        if len(non_batch) == 0:
            # Swap in the batched implementation in place when every full
            # scorer supports batch scoring.
            beam_search.__class__ = BatchBeamSearch
            logging.info("BatchBeamSearch implementation is selected.")
        else:
            logging.warning(
                f"As non-batch scorers {non_batch} are found, "
                f"fall back to non-batch implementation."
            )

    if args.ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    if args.ngpu == 1:
        device = "cuda"
    else:
        device = "cpu"
    dtype = getattr(torch, args.dtype)
    logging.info(f"Decoding device={device}, dtype={dtype}")
    model.to(device=device, dtype=dtype).eval()
    beam_search.to(device=device, dtype=dtype).eval()

    # read json data
    with open(args.recog_json, "rb") as f:
        js = json.load(f)["utts"]
    new_js = {}
    with torch.no_grad():
        for idx, name in enumerate(js, 1):
            # Pass ``name`` as a logging argument so that a "%" in an utterance
            # id is not interpreted as a format directive.
            logging.info("(%d/%d) decoding %s", idx, len(js), name)
            batch = [(name, js[name])]
            feat = load_inputs_and_targets(batch)[0][0]
            enc = model.encode(torch.as_tensor(feat).to(device=device, dtype=dtype))
            nbest_hyps = beam_search(
                x=enc, maxlenratio=args.maxlenratio, minlenratio=args.minlenratio
            )
            # Slicing already clamps to the number of available hypotheses.
            nbest_hyps = [h.asdict() for h in nbest_hyps[: args.nbest]]
            new_js[name] = add_results_to_json(
                js[name], nbest_hyps, train_args.char_list
            )

    with open(args.result_label, "wb") as f:
        f.write(
            json.dumps(
                {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
            ).encode("utf_8")
        )
def test_batch_beam_search_equal(
    model_class,
    args,
    ctc_weight,
    lm_nn,
    lm_args,
    lm_weight,
    ngram_weight,
    bonus,
    device,
    dtype,
):
    """Check BatchBeamSearch reproduces the n-best list of BeamSearch.

    Both searches decode the same encoder output with the same scorers and
    weights. The test requires the same number of hypotheses, identical token
    sequences, and scores equal within ``rtol=1e-6``.
    """
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("no cuda device is available")
    if device == "cpu" and dtype == "float16":
        pytest.skip("cpu float16 implementation is not available in pytorch yet")

    # seed setting
    torch.manual_seed(123)
    torch.backends.cudnn.deterministic = True
    # https://github.com/pytorch/pytorch/issues/6351
    torch.backends.cudnn.benchmark = False

    dtype = getattr(torch, dtype)
    model, x, ilens, y, data, train_args = prepare(
        model_class, args, mtlalpha=ctc_weight
    )
    model.eval()
    char_list = train_args.char_list
    lm = dynamic_import_lm(lm_nn, backend="pytorch")(len(char_list), lm_args)
    lm.eval()
    root = os.path.dirname(os.path.abspath(__file__))
    # Use the same vocabulary as the model and the other scorers.
    ngram = NgramFullScorer(os.path.join(root, "beam_search_test.arpa"), char_list)

    # Decoding hyper-parameters shared by both beam search implementations.
    decode_args = Namespace(
        beam_size=3,
        penalty=bonus,
        ctc_weight=ctc_weight,
        maxlenratio=0,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        minlenratio=0,
        nbest=5,
    )

    # new beam search
    scorers = model.scorers()
    if lm_weight != 0:
        scorers["lm"] = lm
    if ngram_weight != 0:
        scorers["ngram"] = ngram
    scorers["length_bonus"] = LengthBonus(len(char_list))
    weights = dict(
        decoder=1.0 - ctc_weight,
        ctc=ctc_weight,
        lm=decode_args.lm_weight,
        ngram=decode_args.ngram_weight,
        length_bonus=decode_args.penalty,
    )
    model.to(device, dtype=dtype)
    model.eval()
    with torch.no_grad():
        enc = model.encode(x[0, : ilens[0]].to(device, dtype=dtype))

    # NOTE(review): legacy_beam pre-beams on the "decoder" score (unless the
    # search is CTC-only) while `beam` uses BatchBeamSearch's default
    # pre_beam_score_key; confirm this asymmetry is intended.
    legacy_beam = BeamSearch(
        beam_size=decode_args.beam_size,
        vocab_size=len(char_list),
        weights=weights,
        scorers=scorers,
        token_list=train_args.char_list,
        sos=model.sos,
        eos=model.eos,
        pre_beam_score_key=None if ctc_weight == 1.0 else "decoder",
    )
    legacy_beam.to(device, dtype=dtype)
    legacy_beam.eval()

    beam = BatchBeamSearch(
        beam_size=decode_args.beam_size,
        vocab_size=len(char_list),
        weights=weights,
        scorers=scorers,
        token_list=train_args.char_list,
        sos=model.sos,
        eos=model.eos,
    )
    beam.to(device, dtype=dtype)
    beam.eval()

    with torch.no_grad():
        legacy_nbest_bs = legacy_beam(
            x=enc,
            maxlenratio=decode_args.maxlenratio,
            minlenratio=decode_args.minlenratio,
        )
        nbest_bs = beam(
            x=enc,
            maxlenratio=decode_args.maxlenratio,
            minlenratio=decode_args.minlenratio,
        )

    # zip() would silently ignore extra/missing hypotheses; compare counts first.
    assert len(legacy_nbest_bs) == len(nbest_bs)
    for expected, actual in zip(legacy_nbest_bs, nbest_bs):
        assert expected.yseq.tolist() == actual.yseq.tolist()
        numpy.testing.assert_allclose(
            expected.score.cpu(), actual.score.cpu(), rtol=1e-6
        )
def __init__(
    self,
    asr_train_config: Union[Path, str] = None,
    asr_model_file: Union[Path, str] = None,
    transducer_conf: dict = None,
    lm_train_config: Union[Path, str] = None,
    lm_file: Union[Path, str] = None,
    ngram_scorer: str = "full",
    ngram_file: Union[Path, str] = None,
    token_type: str = None,
    bpemodel: str = None,
    device: str = "cpu",
    maxlenratio: float = 0.0,
    minlenratio: float = 0.0,
    batch_size: int = 1,
    dtype: str = "float32",
    beam_size: int = 20,
    ctc_weight: float = 0.5,
    lm_weight: float = 1.0,
    ngram_weight: float = 0.9,
    penalty: float = 0.0,
    nbest: int = 1,
    streaming: bool = False,
    enh_s2t_task: bool = False,
    quantize_asr_model: bool = False,
    quantize_lm: bool = False,
    quantize_modules: List[str] = ["Linear"],
    quantize_dtype: str = "qint8",
):
    """Build the ASR model, scorers, beam search and text converter.

    Args:
        asr_train_config: Training config of the ASR (or EnhS2T) model.
        asr_model_file: Checkpoint of the ASR (or EnhS2T) model.
        transducer_conf: Extra keyword arguments for ``BeamSearchTransducer``
            (only used for transducer models). ``None`` means no extras.
        lm_train_config, lm_file: Optional neural LM config/checkpoint.
        ngram_scorer: ``"full"`` for ``NgramFullScorer``, otherwise
            ``NgramPartScorer``.
        ngram_file: Optional n-gram model file.
        token_type, bpemodel: Tokenizer settings. When ``None``, they are taken
            from the ASR training args.
        device, dtype: Where and in which torch dtype (by name) to decode.
        batch_size: When 1, the batched beam search is used if every full
            scorer supports batch scoring.
        beam_size, ctc_weight, lm_weight, ngram_weight, penalty: Beam search
            size and scorer weights (``penalty`` weights the length bonus).
        nbest: Number of hypotheses to keep.
        streaming: Use ``BatchBeamSearchOnlineSim`` when batching is possible.
        enh_s2t_task: Build the model with ``EnhS2TTask`` instead of ``ASRTask``.
        quantize_asr_model, quantize_lm: Apply dynamic quantization.
        quantize_modules: Names of ``torch.nn`` module classes to quantize.
            The list is rebound to a set below and never modified in place.
        quantize_dtype: Name of the torch dtype used for quantization.

    Raises:
        ValueError: If float16 quantization is requested with torch < 1.5.0.
    """
    assert check_argument_types()

    task = ASRTask if not enh_s2t_task else EnhS2TTask

    if quantize_asr_model or quantize_lm:
        if quantize_dtype == "float16" and torch.__version__ < LooseVersion("1.5.0"):
            raise ValueError(
                "float16 dtype for dynamic quantization is not supported with "
                "torch version < 1.5.0. Switch to qint8 dtype instead."
            )
        quantize_modules = set([getattr(torch.nn, q) for q in quantize_modules])
        quantize_dtype = getattr(torch, quantize_dtype)

    # 1. Build ASR model
    scorers = {}
    asr_model, asr_train_args = task.build_model_from_file(
        asr_train_config, asr_model_file, device
    )
    if enh_s2t_task:
        asr_model.inherite_attributes(
            inherite_s2t_attrs=[
                "ctc",
                "decoder",
                "eos",
                "joint_network",
                "sos",
                "token_list",
                "use_transducer_decoder",
            ]
        )
    asr_model.to(dtype=getattr(torch, dtype)).eval()

    if quantize_asr_model:
        logging.info("Use quantized asr model for decoding.")
        asr_model = torch.quantization.quantize_dynamic(
            asr_model, qconfig_spec=quantize_modules, dtype=quantize_dtype
        )

    decoder = asr_model.decoder
    ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
    token_list = asr_model.token_list
    scorers.update(
        decoder=decoder,
        ctc=ctc,
        length_bonus=LengthBonus(len(token_list)),
    )

    # 2. Build Language model
    if lm_train_config is not None:
        lm, _ = LMTask.build_model_from_file(lm_train_config, lm_file, device)
        if quantize_lm:
            logging.info("Use quantized lm for decoding.")
            lm = torch.quantization.quantize_dynamic(
                lm, qconfig_spec=quantize_modules, dtype=quantize_dtype
            )
        scorers["lm"] = lm.lm

    # 3. Build ngram model
    if ngram_file is not None:
        if ngram_scorer == "full":
            from espnet.nets.scorers.ngram import NgramFullScorer

            ngram = NgramFullScorer(ngram_file, token_list)
        else:
            from espnet.nets.scorers.ngram import NgramPartScorer

            ngram = NgramPartScorer(ngram_file, token_list)
    else:
        ngram = None
    scorers["ngram"] = ngram

    # 4. Build BeamSearch object
    if asr_model.use_transducer_decoder:
        beam_search_transducer = BeamSearchTransducer(
            decoder=asr_model.decoder,
            joint_network=asr_model.joint_network,
            beam_size=beam_size,
            lm=scorers.get("lm"),
            lm_weight=lm_weight,
            # ``transducer_conf`` defaults to None; unpacking None would raise.
            **(transducer_conf or {}),
        )
        beam_search = None
    else:
        beam_search_transducer = None

        weights = dict(
            decoder=1.0 - ctc_weight,
            ctc=ctc_weight,
            lm=lm_weight,
            ngram=ngram_weight,
            length_bonus=penalty,
        )
        beam_search = BeamSearch(
            beam_size=beam_size,
            weights=weights,
            scorers=scorers,
            sos=asr_model.sos,
            eos=asr_model.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
        )

        # TODO(karita): make all scorers batchfied
        if batch_size == 1:
            non_batch = [
                k
                for k, v in beam_search.full_scorers.items()
                if not isinstance(v, BatchScorerInterface)
            ]
            if len(non_batch) == 0:
                if streaming:
                    beam_search.__class__ = BatchBeamSearchOnlineSim
                    beam_search.set_streaming_config(asr_train_config)
                    logging.info(
                        "BatchBeamSearchOnlineSim implementation is selected."
                    )
                else:
                    beam_search.__class__ = BatchBeamSearch
                    logging.info("BatchBeamSearch implementation is selected.")
            else:
                logging.warning(
                    f"As non-batch scorers {non_batch} are found, "
                    f"fall back to non-batch implementation."
                )

        # beam_search is None on the transducer path, so placement stays here.
        beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
        for scorer in scorers.values():
            if isinstance(scorer, torch.nn.Module):
                scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
        logging.info(f"Beam_search: {beam_search}")
        logging.info(f"Decoding device={device}, dtype={dtype}")

    # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
    if token_type is None:
        token_type = asr_train_args.token_type
    if bpemodel is None:
        bpemodel = asr_train_args.bpemodel

    if token_type is None:
        tokenizer = None
    elif token_type == "bpe":
        if bpemodel is not None:
            tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
        else:
            tokenizer = None
    else:
        tokenizer = build_tokenizer(token_type=token_type)
    converter = TokenIDConverter(token_list=token_list)
    logging.info(f"Text tokenizer: {tokenizer}")

    self.asr_model = asr_model
    self.asr_train_args = asr_train_args
    self.converter = converter
    self.tokenizer = tokenizer
    self.beam_search = beam_search
    self.beam_search_transducer = beam_search_transducer
    self.maxlenratio = maxlenratio
    self.minlenratio = minlenratio
    self.device = device
    self.dtype = dtype
    self.nbest = nbest
def recog_v2(args):
    """Decode with custom models that implements ScorerInterface.

    Notes:
        The previous backend espnet.asr.pytorch_backend.asr.recog
        only supports E2E and RNNLM

    Args:
        args (namespace): The program arguments.
            See py:func:`espnet.bin.asr_recog.get_parser` for details

    Raises:
        NotImplementedError: If multi-utterance batch decoding, streaming mode,
            a word LM, or more than one GPU is requested.

    """
    logging.warning("experimental API for custom LMs is selected by --api v2")
    if args.batchsize > 1:
        raise NotImplementedError("multi-utt batch decoding is not implemented")
    if args.streaming_mode is not None:
        raise NotImplementedError("streaming mode is not implemented")
    if args.word_rnnlm:
        raise NotImplementedError("word LM is not implemented")

    set_deterministic_pytorch(args)
    model, train_args = load_trained_model(args.model)
    assert isinstance(model, ASRInterface)
    model.eval()

    load_inputs_and_targets = LoadInputsAndTargets(
        mode="asr",
        load_output=False,
        sort_in_input_length=False,
        preprocess_conf=train_args.preprocess_conf
        if args.preprocess_conf is None
        else args.preprocess_conf,
        preprocess_args={"train": False},
    )

    if args.rnnlm:
        lm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        # NOTE: for a compatibility with less than 0.5.0 version models
        lm_model_module = getattr(lm_args, "model_module", "default")
        lm_class = dynamic_import_lm(lm_model_module, lm_args.backend)
        lm = lm_class(len(train_args.char_list), lm_args)
        torch_load(args.rnnlm, lm)
        lm.eval()
    else:
        lm = None

    if args.ngram_model:
        from espnet.nets.scorers.ngram import NgramFullScorer
        from espnet.nets.scorers.ngram import NgramPartScorer

        if args.ngram_scorer == "full":
            ngram = NgramFullScorer(args.ngram_model, train_args.char_list)
        else:
            ngram = NgramPartScorer(args.ngram_model, train_args.char_list)
    else:
        ngram = None

    scorers = model.scorers()
    scorers["lm"] = lm
    scorers["ngram"] = ngram
    scorers["length_bonus"] = LengthBonus(len(train_args.char_list))
    weights = dict(
        decoder=1.0 - args.ctc_weight,
        ctc=args.ctc_weight,
        lm=args.lm_weight,
        ngram=args.ngram_weight,
        length_bonus=args.penalty,
    )
    beam_search = BeamSearch(
        beam_size=args.beam_size,
        vocab_size=len(train_args.char_list),
        weights=weights,
        scorers=scorers,
        sos=model.sos,
        eos=model.eos,
        token_list=train_args.char_list,
        pre_beam_score_key=None if args.ctc_weight == 1.0 else "full",
    )
    # TODO(karita): make all scorers batchfied
    if args.batchsize == 1:
        non_batch = [
            k
            for k, v in beam_search.full_scorers.items()
            if not isinstance(v, BatchScorerInterface)
        ]
        if len(non_batch) == 0:
            beam_search.__class__ = BatchBeamSearch
            logging.info("BatchBeamSearch implementation is selected.")
        else:
            logging.warning(
                f"As non-batch scorers {non_batch} are found, "
                f"fall back to non-batch implementation."
            )

    if args.ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    if args.ngpu == 1:
        device = "cuda"
    else:
        device = "cpu"
    dtype = getattr(torch, args.dtype)
    logging.info(f"Decoding device={device}, dtype={dtype}")
    model.to(device=device, dtype=dtype).eval()
    beam_search.to(device=device, dtype=dtype).eval()

    # read json data
    # The file may start with the tools setup warning below. Strip exactly that
    # prefix, measured with len() rather than a hard-coded character count.
    env_warning = (
        "Warning! You haven't set Python environment yet. Go to "
        "/content/espnet/tools and generate 'activate_python.sh'"
    )
    # Explicit encoding: the result is not dependent on the platform locale.
    with open(args.recog_json, "r", encoding="utf-8") as f:
        content = f.read()
    if content.startswith(env_warning):
        content = content[len(env_warning):]
    js = json.loads(content)["utts"]

    new_js = {}
    with torch.no_grad():
        for idx, name in enumerate(js, 1):
            # Pass ``name`` as a logging argument so that a "%" in an utterance
            # id is not interpreted as a format directive.
            logging.info("(%d/%d) decoding %s", idx, len(js), name)
            batch = [(name, js[name])]
            feat = load_inputs_and_targets(batch)[0][0]
            enc = model.encode(torch.as_tensor(feat).to(device=device, dtype=dtype))
            nbest_hyps = beam_search(
                x=enc, maxlenratio=args.maxlenratio, minlenratio=args.minlenratio
            )
            # Slicing already clamps to the number of available hypotheses.
            nbest_hyps = [h.asdict() for h in nbest_hyps[: args.nbest]]
            new_js[name] = add_results_to_json(
                js[name], nbest_hyps, train_args.char_list
            )

    with open(args.result_label, "wb") as f:
        f.write(
            json.dumps(
                {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
            ).encode("utf_8")
        )
def __init__(
    self,
    asr_train_config: Union[Path, str] = None,
    asr_model_file: Union[Path, str] = None,
    transducer_conf: dict = None,
    lm_train_config: Union[Path, str] = None,
    lm_file: Union[Path, str] = None,
    ngram_scorer: str = "full",
    ngram_file: Union[Path, str] = None,
    token_type: str = None,
    bpemodel: str = None,
    device: str = "cpu",
    maxlenratio: float = 0.0,
    minlenratio: float = 0.0,
    batch_size: int = 1,
    dtype: str = "float32",
    beam_size: int = 20,
    ctc_weight: float = 0.5,
    lm_weight: float = 1.0,
    ngram_weight: float = 0.9,
    penalty: float = 0.0,
    nbest: int = 1,
    streaming: bool = False,
):
    """Build the ASR model, scorers, beam search and text converter.

    Args:
        asr_train_config, asr_model_file: ASR model config and checkpoint.
        transducer_conf: Extra keyword arguments for ``BeamSearchTransducer``
            (only used for transducer models). ``None`` means no extras.
        lm_train_config, lm_file: Optional neural LM config/checkpoint.
        ngram_scorer: ``"full"`` for ``NgramFullScorer``, otherwise
            ``NgramPartScorer``.
        ngram_file: Optional n-gram model file.
        token_type, bpemodel: Tokenizer settings. When ``None``, they are taken
            from the ASR training args.
        device, dtype: Where and in which torch dtype (by name) to decode.
        batch_size: When 1, the batched beam search is used if every full
            scorer supports batch scoring.
        beam_size, ctc_weight, lm_weight, ngram_weight, penalty: Beam search
            size and scorer weights (``penalty`` weights the length bonus).
        nbest: Number of hypotheses to keep.
        streaming: Use ``BatchBeamSearchOnlineSim`` when batching is possible.
    """
    assert check_argument_types()

    # 1. Build ASR model
    scorers = {}
    asr_model, asr_train_args = ASRTask.build_model_from_file(
        asr_train_config, asr_model_file, device
    )
    asr_model.to(dtype=getattr(torch, dtype)).eval()

    decoder = asr_model.decoder
    ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
    token_list = asr_model.token_list
    scorers.update(
        decoder=decoder,
        ctc=ctc,
        length_bonus=LengthBonus(len(token_list)),
    )

    # 2. Build Language model
    if lm_train_config is not None:
        lm, _ = LMTask.build_model_from_file(lm_train_config, lm_file, device)
        scorers["lm"] = lm.lm

    # 3. Build ngram model
    if ngram_file is not None:
        if ngram_scorer == "full":
            from espnet.nets.scorers.ngram import NgramFullScorer

            ngram = NgramFullScorer(ngram_file, token_list)
        else:
            from espnet.nets.scorers.ngram import NgramPartScorer

            ngram = NgramPartScorer(ngram_file, token_list)
    else:
        ngram = None
    scorers["ngram"] = ngram

    # 4. Build BeamSearch object
    if asr_model.use_transducer_decoder:
        beam_search_transducer = BeamSearchTransducer(
            decoder=asr_model.decoder,
            joint_network=asr_model.joint_network,
            beam_size=beam_size,
            lm=scorers.get("lm"),
            lm_weight=lm_weight,
            # ``transducer_conf`` defaults to None; unpacking None would raise.
            **(transducer_conf or {}),
        )
        beam_search = None
    else:
        beam_search_transducer = None

        # The ngram weight must be listed here: scorers without a weight
        # entry are not used by the beam search.
        weights = dict(
            decoder=1.0 - ctc_weight,
            ctc=ctc_weight,
            lm=lm_weight,
            ngram=ngram_weight,
            length_bonus=penalty,
        )
        beam_search = BeamSearch(
            beam_size=beam_size,
            weights=weights,
            scorers=scorers,
            sos=asr_model.sos,
            eos=asr_model.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
        )

        # TODO(karita): make all scorers batchfied
        if batch_size == 1:
            non_batch = [
                k
                for k, v in beam_search.full_scorers.items()
                if not isinstance(v, BatchScorerInterface)
            ]
            if len(non_batch) == 0:
                if streaming:
                    beam_search.__class__ = BatchBeamSearchOnlineSim
                    beam_search.set_streaming_config(asr_train_config)
                    logging.info(
                        "BatchBeamSearchOnlineSim implementation is selected."
                    )
                else:
                    beam_search.__class__ = BatchBeamSearch
                    logging.info("BatchBeamSearch implementation is selected.")
            else:
                logging.warning(
                    f"As non-batch scorers {non_batch} are found, "
                    f"fall back to non-batch implementation."
                )

        # beam_search is None on the transducer path, so placement stays here.
        beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
        for scorer in scorers.values():
            if isinstance(scorer, torch.nn.Module):
                scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
        logging.info(f"Beam_search: {beam_search}")
        logging.info(f"Decoding device={device}, dtype={dtype}")

    # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
    if token_type is None:
        token_type = asr_train_args.token_type
    if bpemodel is None:
        bpemodel = asr_train_args.bpemodel

    if token_type is None:
        tokenizer = None
    elif token_type == "bpe":
        if bpemodel is not None:
            tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
        else:
            tokenizer = None
    else:
        tokenizer = build_tokenizer(token_type=token_type)
    converter = TokenIDConverter(token_list=token_list)
    logging.info(f"Text tokenizer: {tokenizer}")

    self.asr_model = asr_model
    self.asr_train_args = asr_train_args
    self.converter = converter
    self.tokenizer = tokenizer
    self.beam_search = beam_search
    self.beam_search_transducer = beam_search_transducer
    self.maxlenratio = maxlenratio
    self.minlenratio = minlenratio
    self.device = device
    self.dtype = dtype
    self.nbest = nbest
def __init__(
    self,
    st_train_config: Union[Path, str] = None,
    st_model_file: Union[Path, str] = None,
    lm_train_config: Union[Path, str] = None,
    lm_file: Union[Path, str] = None,
    ngram_scorer: str = "full",
    ngram_file: Union[Path, str] = None,
    token_type: str = None,
    bpemodel: str = None,
    device: str = "cpu",
    maxlenratio: float = 0.0,
    minlenratio: float = 0.0,
    batch_size: int = 1,
    dtype: str = "float32",
    beam_size: int = 20,
    lm_weight: float = 1.0,
    ngram_weight: float = 0.9,
    penalty: float = 0.0,
    nbest: int = 1,
    enh_s2t_task: bool = False,
):
    """Assemble the speech-translation model, scorers and beam search.

    Steps: (1) rebuild the ST (or EnhS2T) model, (2) optionally load a neural
    LM, (3) optionally load an n-gram scorer, (4) build the beam search, and
    (5) build the token-to-text converter.

    Args:
        st_train_config, st_model_file: ST model config and checkpoint.
        lm_train_config, lm_file: Optional neural LM config/checkpoint.
        ngram_scorer: ``"full"`` for ``NgramFullScorer``, otherwise
            ``NgramPartScorer``.
        ngram_file: Optional n-gram model file.
        token_type, bpemodel: Tokenizer settings. When ``None``, they are taken
            from the ST training args.
        device, dtype: Where and in which torch dtype (by name) to decode.
        batch_size: When 1, the batched beam search is used if every full
            scorer supports batch scoring.
        beam_size, lm_weight, ngram_weight, penalty: Beam size and scorer
            weights (the decoder weight is fixed at 1.0).
        nbest: Number of hypotheses to keep.
        enh_s2t_task: Build the model with ``EnhS2TTask`` instead of ``STTask``.
    """
    assert check_argument_types()

    task_cls = EnhS2TTask if enh_s2t_task else STTask

    # Step 1: translation model.
    st_model, st_train_args = task_cls.build_model_from_file(
        st_train_config, st_model_file, device
    )
    if enh_s2t_task:
        inherited_attrs = [
            "ctc",
            "decoder",
            "eos",
            "joint_network",
            "sos",
            "token_list",
            "use_transducer_decoder",
        ]
        st_model.inherite_attributes(inherite_s2t_attrs=inherited_attrs)
    torch_dtype = getattr(torch, dtype)
    st_model.to(dtype=torch_dtype).eval()

    vocab = st_model.token_list
    # Insertion order (decoder, length_bonus, lm, ngram) is kept deliberately.
    beam_scorers = {
        "decoder": st_model.decoder,
        "length_bonus": LengthBonus(len(vocab)),
    }

    # Step 2: optional neural LM.
    if lm_train_config is not None:
        lm_model, _ = LMTask.build_model_from_file(lm_train_config, lm_file, device)
        beam_scorers["lm"] = lm_model.lm

    # Step 3: optional n-gram LM.
    ngram = None
    if ngram_file is not None:
        if ngram_scorer == "full":
            from espnet.nets.scorers.ngram import NgramFullScorer as NgramScorerCls
        else:
            from espnet.nets.scorers.ngram import NgramPartScorer as NgramScorerCls
        ngram = NgramScorerCls(ngram_file, vocab)
    beam_scorers["ngram"] = ngram

    # Step 4: beam search.
    beam_search = BeamSearch(
        beam_size=beam_size,
        weights={
            "decoder": 1.0,
            "lm": lm_weight,
            "ngram": ngram_weight,
            "length_bonus": penalty,
        },
        scorers=beam_scorers,
        sos=st_model.sos,
        eos=st_model.eos,
        vocab_size=len(vocab),
        token_list=vocab,
        pre_beam_score_key="full",
    )
    # TODO(karita): make all scorers batchfied
    if batch_size == 1:
        non_batch = [
            scorer_name
            for scorer_name, scorer_obj in beam_search.full_scorers.items()
            if not isinstance(scorer_obj, BatchScorerInterface)
        ]
        if non_batch:
            logging.warning(
                f"As non-batch scorers {non_batch} are found, "
                f"fall back to non-batch implementation."
            )
        else:
            # Every full scorer is batch-capable: switch class in place.
            beam_search.__class__ = BatchBeamSearch
            logging.info("BatchBeamSearch implementation is selected.")

    beam_search.to(device=device, dtype=torch_dtype).eval()
    for scorer_obj in beam_scorers.values():
        if isinstance(scorer_obj, torch.nn.Module):
            scorer_obj.to(device=device, dtype=torch_dtype).eval()
    logging.info(f"Beam_search: {beam_search}")
    logging.info(f"Decoding device={device}, dtype={dtype}")

    # Step 5: [Optional] text converter, e.g. bpe-sym -> Text.
    if token_type is None:
        token_type = st_train_args.token_type
    if bpemodel is None:
        bpemodel = st_train_args.bpemodel

    if token_type is None:
        tokenizer = None
    elif token_type != "bpe":
        tokenizer = build_tokenizer(token_type=token_type)
    elif bpemodel is None:
        tokenizer = None
    else:
        tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
    converter = TokenIDConverter(token_list=vocab)
    logging.info(f"Text tokenizer: {tokenizer}")

    self.st_model = st_model
    self.st_train_args = st_train_args
    self.converter = converter
    self.tokenizer = tokenizer
    self.beam_search = beam_search
    self.maxlenratio = maxlenratio
    self.minlenratio = minlenratio
    self.device = device
    self.dtype = dtype
    self.nbest = nbest