def __init__(self, conffile=None):
    if conffile is not None:
        if isinstance(conffile, dict):
            self.conf = copy.deepcopy(conffile)
        else:
            with io.open(conffile, encoding="utf-8") as f:
                self.conf = yaml.safe_load(f)
                assert isinstance(self.conf, dict), type(self.conf)
    else:
        self.conf = {"mode": "sequential", "process": []}

    self.functions = OrderedDict()
    if self.conf.get("mode", "sequential") == "sequential":
        for idx, process in enumerate(self.conf["process"]):
            assert isinstance(process, dict), type(process)
            opts = dict(process)
            process_type = opts.pop("type")
            class_obj = dynamic_import(process_type, import_alias)
            # TODO(karita): assert issubclass(class_obj, TransformInterface)
            try:
                self.functions[idx] = class_obj(**opts)
            except TypeError:
                try:
                    signa = signature(class_obj)
                except ValueError:
                    # Some callables, e.g. built-in functions, do not expose
                    # a signature, so there is nothing useful to report here.
                    pass
                else:
                    logging.error(
                        "Expected signature: {}({})".format(class_obj.__name__, signa)
                    )
                raise
    else:
        raise NotImplementedError(
            "Not supporting mode={}".format(self.conf["mode"])
        )
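
# Illustrative only: a minimal, hedged sketch of how the constructor above is
# typically driven by a transform config.  It assumes the enclosing class is
# named Transformation (as in espnet.transform.transformation) and that
# "fbank" is one of the aliases registered in `import_alias`; adjust both to
# the actual names in this module.
def _example_build_transformation():
    conf = {
        "mode": "sequential",
        "process": [{"type": "fbank", "fs": 16000, "n_mels": 80}],
    }
    # Each "process" entry is resolved through dynamic_import() and
    # instantiated with the remaining keys as keyword arguments.
    return Transformation(conf)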
def get_trained_model_state_dict(model_path):
    """Extract the trained model state dict for pre-initialization.

    Args:
        model_path (str): Path to model.***.best

    Return:
        model.state_dict() (OrderedDict): the loaded model state_dict
        (bool): Boolean defining whether the model is an LM

    """
    conf_path = os.path.join(os.path.dirname(model_path), "model.json")
    if "rnnlm" in model_path:
        logging.warning("reading model parameters from %s", model_path)

        return torch.load(model_path), True

    idim, odim, args = get_model_conf(model_path, conf_path)

    logging.warning("reading model parameters from " + model_path)

    if hasattr(args, "model_module"):
        model_module = args.model_module
    else:
        model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"

    model_class = dynamic_import(model_module)
    model = model_class(idim, odim, args)
    torch_load(model_path, model)
    assert isinstance(model, ASRInterface)

    return model.state_dict(), False
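
# Illustrative only: a hedged sketch of calling get_trained_model_state_dict().
# The checkpoint path is a placeholder; the function expects a model.json next
# to it unless the path points at an RNNLM checkpoint.
def _example_get_pretrained_state_dict():
    state_dict, is_lm = get_trained_model_state_dict(
        "exp/train/results/model.acc.best"
    )
    # is_lm is True only for RNNLM checkpoints, which are returned directly
    # from torch.load() rather than rebuilt through the model class.
    return state_dict, is_lm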
def load_trained_model(model_path, training=True):
    """Load the trained model for recognition.

    Args:
        model_path (str): Path to model.***.best

    """
    idim, odim, train_args = get_model_conf(
        model_path, os.path.join(os.path.dirname(model_path), "model.json")
    )

    logging.warning("reading model parameters from " + model_path)

    if hasattr(train_args, "model_module"):
        model_module = train_args.model_module
    else:
        model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"

    # CTC loss is not needed at decoding time; default to the built-in
    # implementation to prevent import errors
    if hasattr(train_args, "ctc_type"):
        train_args.ctc_type = "builtin"

    model_class = dynamic_import(model_module)

    if "transducer" in model_module:
        model = model_class(idim, odim, train_args, training)
    else:
        model = model_class(idim, odim, train_args)

    torch_load(model_path, model)

    return model, train_args
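
# Illustrative only: a minimal sketch of loading a trained model for decoding.
# The path is a placeholder; training=False is only consulted for transducer
# model modules, as the branch above shows.
def _example_load_for_decoding():
    model, train_args = load_trained_model(
        "exp/train/results/model.acc.best", training=False
    )
    model.eval()  # disable dropout before recognition
    return model, train_args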
def dynamic_import_lm(module, backend):
    """Import LM class dynamically.

    Args:
        module (str): module_name:class_name or alias in `predefined_lms`
        backend (str): NN backend. e.g., pytorch, chainer

    Returns:
        type: LM class

    """
    model_class = dynamic_import(module, predefined_lms.get(backend, dict()))
    assert issubclass(
        model_class, LMInterface
    ), f"{module} does not implement LMInterface"
    return model_class
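
# Illustrative only: a hedged sketch of resolving an LM class.  It assumes
# "default" is one of the aliases registered in `predefined_lms` for the
# pytorch backend; a fully qualified "module:Class" string works as well.
def _example_import_lm_class():
    lm_class = dynamic_import_lm("default", backend="pytorch")
    # The assert above guarantees the returned class implements LMInterface.
    return lm_class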
def load_trained_model(model_path):
    """Load the trained model for recognition.

    Args:
        model_path (str): Path to model.***.best

    """
    idim, odim, train_args = get_model_conf(
        model_path, os.path.join(os.path.dirname(model_path), "model.json")
    )

    logging.warning("reading model parameters from " + model_path)

    if hasattr(train_args, "model_module"):
        model_module = train_args.model_module
    else:
        model_module = "tt.model:Transducer"

    model_class = dynamic_import(model_module)
    model = model_class(idim, odim, train_args)
    torch_load(model_path, model)

    return model, train_args
def load_trained_modules(idim, odim, args, interface=ASRInterface):
    """Load model encoder and/or decoder modules with ESPnet pre-trained model(s).

    Args:
        idim (int): initial input dimension.
        odim (int): initial output dimension.
        args (Namespace): The initial model arguments.
        interface (Interface): ASRInterface or STInterface or TTSInterface.

    Return:
        model (torch.nn.Module): The model with pretrained modules.

    """

    def print_new_keys(state_dict, modules, model_path):
        logging.warning("loading %s from model: %s", modules, model_path)

        for k in state_dict.keys():
            logging.warning("override %s" % k)

    enc_model_path = args.enc_init
    dec_model_path = args.dec_init
    enc_modules = args.enc_init_mods
    dec_modules = args.dec_init_mods

    model_class = dynamic_import(args.model_module)
    main_model = model_class(idim, odim, args)
    assert isinstance(main_model, interface)

    main_state_dict = main_model.state_dict()

    logging.warning("model(s) found for pre-initialization")

    for model_path, modules in [
        (enc_model_path, enc_modules),
        (dec_model_path, dec_modules),
    ]:
        if model_path is not None:
            if os.path.isfile(model_path):
                model_state_dict, is_lm = get_trained_model_state_dict(model_path)

                modules = filter_modules(model_state_dict, modules)

                if is_lm:
                    partial_state_dict, modules = get_partial_lm_state_dict(
                        model_state_dict, modules
                    )
                    print_new_keys(partial_state_dict, modules, model_path)
                else:
                    partial_state_dict = get_partial_state_dict(
                        model_state_dict, modules
                    )

                if partial_state_dict:
                    if transfer_verification(
                        main_state_dict, partial_state_dict, modules
                    ):
                        print_new_keys(partial_state_dict, modules, model_path)
                        main_state_dict.update(partial_state_dict)
                    else:
                        logging.warning(
                            f"modules {modules} in model {model_path} "
                            f"don't match your training config",
                        )
            else:
                logging.warning("model was not found : %s", model_path)

    main_model.load_state_dict(main_state_dict)

    return main_model
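
# Illustrative only: a minimal, hedged sketch of transfer initialization with
# load_trained_modules().  The paths and module prefixes are placeholders, and
# args is assumed to already hold the full training configuration (including
# args.model_module), since the main model is built from it above.
def _example_transfer_init(idim, odim, args):
    args.enc_init = "exp/pretrained_asr/results/model.acc.best"
    args.enc_init_mods = ["enc."]          # copy only encoder parameters
    args.dec_init = None                   # keep the decoder randomly initialized
    args.dec_init_mods = ["dec.", "att."]  # ignored here because dec_init is None
    return load_trained_modules(idim, odim, args)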