def get_registered_adapter(cls: Union[str, type]) -> Optional[AdapterRegistryInfo]:
    """
    Look up the adapter registry entry that corresponds to `cls`.

    Args:
        cls: One of
            - a str, which is imported with `model_utils.import_class_by_path`,
            - a base class,
            - an adapter class (recognised by its `_meta_base_class` attribute).

    Returns:
        The AdapterRegistryInfo stored in ADAPTER_REGISTRY under the resolved
        `<module>.<name>` key, or None when no such key exists.
    """
    resolved = model_utils.import_class_by_path(cls) if isinstance(cls, str) else cls

    # Adapter classes point back at their base class through `_meta_base_class`;
    # the registry lookup below is done on that base class.
    resolved = getattr(resolved, '_meta_base_class', resolved)

    registry_key = f'{resolved.__module__}.{resolved.__name__}'
    return ADAPTER_REGISTRY.get(registry_key)
def _get_class_from_path(domain, subdomains, imp):
    """
    Import the object addressed by (domain, subdomains, imp) and classify it.

    Returns:
        A tuple `(class_, result, error)`:
            class_: the imported class (unwrapped if it was a wrapt FunctionWrapper),
                or None when the imported object is not a class.
            result: `class_` when it subclasses Model or torch.nn.Module, else None.
            error: the formatted traceback if any step raised, else None.
    """
    import_path = _build_import_path(domain, subdomains, imp)
    class_, result, error = None, None, None

    try:
        class_ = model_utils.import_class_by_path(import_path)

        if not inspect.isclass(class_):
            class_ = None
        else:
            # Is the class wrapped in a wrapt decorator at the class level? Unwrap it for the checks.
            if isinstance(class_, wrapt.FunctionWrapper):
                class_ = class_.__wrapped__

            # Subclass test
            if issubclass(class_, (Model, torch.nn.Module)):
                result = class_
    except Exception:
        # Best effort: the failure is handed back to the caller as text instead of propagating.
        error = traceback.format_exc()

    return class_, result, error
def initialize_model(model_name):
    """
    Return the ASR model for `model_name`, loading and caching it on first use.

    A name containing '.nemo' is treated as a local checkpoint located at
    `models/<name without extension>/<name>`; any other name is passed to
    `ASRModel.from_pretrained`. Newly loaded models are frozen and stored in
    MODEL_CACHE before being returned.
    """
    # Cache hit: nothing to load.
    if model_name in MODEL_CACHE:
        return MODEL_CACHE[model_name]

    if '.nemo' in model_name:
        # Local model
        stem = os.path.splitext(model_name)[0]
        model_path = os.path.join('models', stem, model_name)

        # Read only the config first to find out which class saved the checkpoint.
        model_cfg = nemo_asr.models.ASRModel.restore_from(restore_path=model_path, return_config=True)
        imported_class = model_utils.import_class_by_path(model_cfg.target)  # type: ASRModel
        logging.info(f"Restoring local model : {imported_class.__name__}")

        # Restore the checkpoint through that class.
        model = imported_class.restore_from(restore_path=model_path, map_location='cpu')  # type: ASRModel
    else:
        # Pretrained model referenced by name
        model = nemo_asr.models.ASRModel.from_pretrained(model_name, map_location='cpu')

    model.freeze()
    MODEL_CACHE[model_name] = model
    return MODEL_CACHE[model_name]
def from_config_dict(cls, config: 'DictConfig'):
    """
    Build an instance from a DictConfig-style configuration.

    When Hydra is available and the config carries Hydra target keys
    (`cls`/`target` together with `params`, or `_target_`), instantiation is
    delegated to `hydra.utils.instantiate`. Otherwise the class named by
    `config["target"]` is tried first and `cls` is used as the fallback.
    The config is attached as `_cfg` if the instance does not already have one.
    """
    # Resolve the config dict
    if _HAS_HYDRA:
        if isinstance(config, DictConfig):
            resolved = OmegaConf.to_container(config, resolve=True)
            config = OmegaConf.create(resolved)
            OmegaConf.set_struct(config, True)

        config = maybe_update_config_version(config)

    # Hydra 0.x API uses (`cls` | `target`) + `params`; Hydra 1.x API uses `_target_`.
    hydra_0x_keys = ('cls' in config or 'target' in config) and 'params' in config
    if _HAS_HYDRA and (hydra_0x_keys or '_target_' in config):
        # regular hydra-based instantiation
        instance = hydra.utils.instantiate(config=config)
    else:
        instance = None
        imported_cls_tb = None

        # Attempt class path resolution from config `target` class (if it exists)
        if 'target' in config:
            target_cls = config["target"]  # No guarantee that this is a omegaconf class
            try:
                imported_cls = import_class_by_path(target_cls)
            except Exception:
                imported_cls = None
                imported_cls_tb = traceback.format_exc()

            if imported_cls is not None:
                # If the calling class is a subclass of the imported class, prefer the subclass.
                chosen_cls = cls if issubclass(cls, imported_cls) else imported_cls
                try:
                    instance = chosen_cls(cfg=config)
                except Exception:
                    imported_cls_tb = traceback.format_exc()
                    instance = None

        # target class resolution was unsuccessful, fall back to current `cls`
        if instance is None:
            if imported_cls_tb is not None:
                logging.debug(
                    f"Model instantiation from target class {target_cls} failed with following error.\n"
                    f"Falling back to `cls`.\n"
                    f"{imported_cls_tb}"
                )
            instance = cls(cfg=config)

    if not hasattr(instance, '_cfg'):
        instance._cfg = config
    return instance
def from_config_dict(cls, config: DictConfig):
    """
    Instantiate an object from a DictConfig-based configuration.

    Resolution order:
        1. Configs with Hydra target keys (`cls`/`target` plus `params`, or
           `_target_`) are passed to `hydra.utils.instantiate`.
        2. Otherwise, if the config has a `target` entry, that class is imported
           and constructed with `cfg=config`.
        3. If step 2 is not possible or fails for any reason, `cls(cfg=config)`
           is used.

    Args:
        config: The configuration. A DictConfig is resolved (interpolations
            expanded) and re-created in struct mode before use.

    Returns:
        The constructed instance; `_cfg` is set to the config if the instance
        does not already define it.

    Raises:
        Whatever `cls(cfg=config)` raises when the final fallback fails.
    """
    # Resolve the config dict
    if isinstance(config, DictConfig):
        config = OmegaConf.to_container(config, resolve=True)
        config = OmegaConf.create(config)
        OmegaConf.set_struct(config, True)

    config = maybe_update_config_version(config)

    # Hydra 0.x API (`cls`/`target` + `params`) or Hydra 1.x API (`_target_`)
    if (('cls' in config or 'target' in config) and 'params' in config) or '_target_' in config:
        # regular hydra-based instantiation
        instance = hydra.utils.instantiate(config=config)
    else:
        instance = None

        # Attempt class path resolution from config `target` class (if it exists)
        if 'target' in config:
            # Item access works for both DictConfig and plain mappings; attribute
            # access (`config.target`) would fail on the latter.
            target_cls = config["target"]
            imported_cls = None
            try:
                imported_cls = import_class_by_path(target_cls)
            except Exception:
                # Any import-time failure (missing module, missing attribute, malformed
                # path, ...) must lead to the `cls` fallback rather than abort it.
                logging.debug(
                    f'Target class `{target_cls}` could not be imported, falling back to original cls\n'
                    f'{traceback.format_exc()}'
                )

            # try instantiating model with target class
            if imported_cls is not None:
                try:
                    instance = imported_cls(cfg=config)
                except Exception:
                    imported_cls_tb = traceback.format_exc()
                    logging.debug(
                        f"Model instantiation from target class failed with following error.\n"
                        f"Falling back to `cls`.\n"
                        f"{imported_cls_tb}"
                    )
                    instance = None

        # target class resolution was unsuccessful, fall back to current `cls`
        if instance is None:
            instance = cls(cfg=config)

    if not hasattr(instance, '_cfg'):
        instance._cfg = config
    return instance
def main(cfg: TranscriptionConfig): logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') if cfg.model_path is None and cfg.pretrained_name is None: raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None !") # setup gpu if cfg.cuda is None: cfg.cuda = torch.cuda.is_available() if type(cfg.cuda) == int: device_id = int(cfg.cuda) else: device_id = 0 device = torch.device(f'cuda:{device_id}' if cfg.cuda else 'cpu') # setup model if cfg.model_path is not None: # restore model from .nemo file path model_cfg = ASRModel.restore_from(restore_path=cfg.model_path, return_config=True) classpath = model_cfg.target # original class path imported_class = model_utils.import_class_by_path(classpath) # type: ASRModel logging.info(f"Restoring model : {imported_class.__name__}") asr_model = imported_class.restore_from(restore_path=cfg.model_path, map_location=device) # type: ASRModel else: # restore model by name asr_model = ASRModel.from_pretrained(model_name=cfg.pretrained_name, map_location=device) # type: ASRModel trainer = pl.Trainer(gpus=int(cfg.cuda)) asr_model.set_trainer(trainer) asr_model = asr_model.eval() # Setup decoding strategy if hasattr(asr_model, 'change_decoding_strategy'): asr_model.change_decoding_strategy(cfg.rnnt_decoding) # load paths to audio filepaths = list(glob.glob(os.path.join(cfg.audio_dir, f"*.{cfg.audio_type}"))) logging.info(f"\nTranscribing {len(filepaths)} files...\n") # setup AMP (optional) if cfg.amp and torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'): logging.info("AMP enabled!\n") autocast = torch.cuda.amp.autocast else: @contextlib.contextmanager def autocast(): yield # transcribe audio with autocast(): with torch.no_grad(): transcriptions = asr_model.transcribe(filepaths, batch_size=cfg.batch_size) logging.info(f"Finished transcribing {len(filepaths)} files !") logging.info(f"Writing transcriptions into file: {cfg.output_filename}") with open(cfg.output_filename, 'w', 
encoding='utf-8') as f: for line in transcriptions: f.write(f"{line}\n") logging.info("Finished writing predictions !")
def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
    """
    Transcribe audio with an ASR model and write the predictions as JSON lines.

    Audio comes either from `cfg.audio_dir` (all `*.{cfg.audio_type}` files) or
    from the `audio_filepath` field of each line in `cfg.dataset_manifest`.
    With a manifest, each output line is the input item plus `pred_text`.

    Args:
        cfg: Transcription configuration. `cfg.cuda` and `cfg.output_filename`
            are filled in here when they are None.

    Returns:
        The (updated) `cfg`. Returned early, without transcribing, when the
        output file already exists and `cfg.overwrite_transcripts` is false.

    Raises:
        ValueError: if neither a model source nor an audio source is configured.
    """
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    if cfg.model_path is None and cfg.pretrained_name is None:
        raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!")
    if cfg.audio_dir is None and cfg.dataset_manifest is None:
        raise ValueError("Both cfg.audio_dir and cfg.dataset_manifest cannot be None!")

    # setup GPU
    if cfg.cuda is None:
        if torch.cuda.is_available():
            cfg.cuda = 0  # use 0th CUDA device
        else:
            cfg.cuda = -1  # use CPU

    device = torch.device(f'cuda:{cfg.cuda}' if cfg.cuda >= 0 else 'cpu')

    # setup model
    if cfg.model_path is not None:
        # restore model from .nemo file path, through the class recorded in its config
        model_cfg = ASRModel.restore_from(restore_path=cfg.model_path, return_config=True)
        classpath = model_cfg.target  # original class path
        imported_class = model_utils.import_class_by_path(classpath)  # type: ASRModel
        logging.info(f"Restoring model : {imported_class.__name__}")
        asr_model = imported_class.restore_from(restore_path=cfg.model_path, map_location=device)  # type: ASRModel
        model_name = os.path.splitext(os.path.basename(cfg.model_path))[0]
    else:
        # restore model by name
        asr_model = ASRModel.from_pretrained(model_name=cfg.pretrained_name, map_location=device)  # type: ASRModel
        model_name = cfg.pretrained_name

    trainer = pl.Trainer(gpus=[cfg.cuda] if cfg.cuda >= 0 else 0)
    asr_model.set_trainer(trainer)
    asr_model = asr_model.eval()

    # Setup decoding strategy
    if hasattr(asr_model, 'change_decoding_strategy'):
        asr_model.change_decoding_strategy(cfg.rnnt_decoding)

    # get audio filenames
    if cfg.audio_dir is not None:
        filepaths = list(glob.glob(os.path.join(cfg.audio_dir, f"*.{cfg.audio_type}")))
    else:
        # get filenames from manifest
        filepaths = []
        with open(cfg.dataset_manifest, 'r') as f:
            for line in f:
                item = json.loads(line)
                filepaths.append(item['audio_filepath'])
    logging.info(f"\nTranscribing {len(filepaths)} files...\n")

    # setup AMP (optional)
    if cfg.amp and torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
        logging.info("AMP enabled!\n")
        autocast = torch.cuda.amp.autocast
    else:

        @contextlib.contextmanager
        def autocast():
            yield

    # Compute output filename
    if cfg.output_filename is None:
        # create default output filename
        if cfg.audio_dir is not None:
            cfg.output_filename = os.path.dirname(os.path.join(cfg.audio_dir, '.')) + '.json'
        else:
            # Insert the model name before the manifest's extension. Splitting on the
            # extension (instead of str.replace('.json', ...)) guarantees the result
            # differs from the input manifest, so the manifest is never truncated by the
            # write below, and leaves '.json' inside directory names untouched.
            manifest_root, manifest_ext = os.path.splitext(cfg.dataset_manifest)
            output_ext = manifest_ext or '.json'
            cfg.output_filename = f'{manifest_root}_{model_name}{output_ext}'

    # if transcripts should not be overwritten, and already exists, skip re-transcription step and return
    if not cfg.overwrite_transcripts and os.path.exists(cfg.output_filename):
        logging.info(
            f"Previous transcripts found at {cfg.output_filename}, and flag `overwrite_transcripts` "
            f"is {cfg.overwrite_transcripts}. Returning without re-transcribing text."
        )
        return cfg

    # transcribe audio
    with autocast():
        with torch.no_grad():
            transcriptions = asr_model.transcribe(filepaths, batch_size=cfg.batch_size)
    logging.info(f"Finished transcribing {len(filepaths)} files !")

    logging.info(f"Writing transcriptions into file: {cfg.output_filename}")

    # if transcriptions form a tuple (from RNNT), extract just "best" hypothesis
    if isinstance(transcriptions, tuple) and len(transcriptions) == 2:
        transcriptions = transcriptions[0]

    # write audio transcriptions
    with open(cfg.output_filename, 'w', encoding='utf-8') as f:
        if cfg.audio_dir is not None:
            for idx, text in enumerate(transcriptions):
                item = {'audio_filepath': filepaths[idx], 'pred_text': text}
                f.write(json.dumps(item) + "\n")
        else:
            # Re-read the manifest so every original field is carried into the output.
            with open(cfg.dataset_manifest, 'r') as fr:
                for idx, line in enumerate(fr):
                    item = json.loads(line)
                    item['pred_text'] = transcriptions[idx]
                    f.write(json.dumps(item) + "\n")

    logging.info("Finished writing predictions !")
    return cfg
def main():
    """
    Command-line entry point: re-partition a .nemo checkpoint from one tensor-parallel (TP) size to another.

    Flow, as implemented below:
        1. If the source TP size is > 1, restore the model once per TP rank, collect each rank's
           parameters, then build a TP=1 model and call `merge_partition`.
        2. If the source TP size is 1, restore the model directly.
        3. If the target TP size is > 1, build a model configured for the target TP size and call
           `split_partition` with `args.target_file`.

    Pipeline parallelism is fixed to 1 (see the inline notes).
    """
    parser = ArgumentParser()
    parser.add_argument("--model_file", type=str, required=True, help="Path to source .nemo file")
    parser.add_argument("--target_file", type=str, required=True, help="Path to write target .nemo file")
    parser.add_argument("--tensor_model_parallel_size", type=int, required=True, help="TP size of source model")
    parser.add_argument("--target_tensor_model_parallel_size", type=int, required=True, help="TP size of target model")
    parser.add_argument(
        "--model_class",
        type=str,
        default="nemo.collections.nlp.models.language_modeling.megatron_gpt_model.MegatronGPTModel",
        help="NeMo model class. This script should support all NeMo megatron models that use Tensor Parallel",
    )
    parser.add_argument("--precision", default=16, help="PyTorch Lightning Trainer precision flag")

    args = parser.parse_args()

    # Only the strings "32" and "16" are converted to int; any other value
    # (including the int default 16) is handed to the Trainer unchanged.
    precision = args.precision
    if args.precision in ["32", "16"]:
        precision = int(float(args.precision))
    tp_size = args.tensor_model_parallel_size
    tgt_tp_size = args.target_tensor_model_parallel_size
    cls = model_utils.import_class_by_path(args.model_class)

    trainer = Trainer(devices=1, plugins=NLPDDPPlugin(), accelerator="cpu", precision=precision)
    app_state = AppState()
    app_state.data_parallel_rank = 0
    app_state.pipeline_model_parallel_size = 1  # not supported yet in this script
    app_state.tensor_model_parallel_size = tp_size
    app_state.model_parallel_size = app_state.pipeline_model_parallel_size * app_state.tensor_model_parallel_size

    if tp_size > 1:
        # Source is tensor-parallel: restore once per rank and keep each rank's parameter list.
        partitions = []
        for i in range(tp_size):
            # Select which TP rank the next restore_from call should load.
            app_state.tensor_model_parallel_rank = i
            model = cls.restore_from(restore_path=args.model_file, trainer=trainer, map_location=torch.device("cpu"))
            params = [p for _, p in model.named_parameters()]
            partitions.append(params)
            # app_state is being updated incorrectly during restore
            app_state.data_parallel_rank = 0
            app_state.pipeline_model_parallel_size = 1  # not supported yet in this script
            app_state.tensor_model_parallel_size = tp_size
            app_state.model_parallel_size = (
                app_state.pipeline_model_parallel_size * app_state.tensor_model_parallel_size)

        # Re-create the model with TP=1 as the destination for the merged parameters.
        model.cfg.tensor_model_parallel_size = 1
        app_state.model_parallel_size = 1
        trainer = Trainer(devices=1, plugins=NLPDDPPlugin(), accelerator="cpu", precision=precision)
        model = cls(model.cfg, trainer).to('cpu')
        model._save_restore_connector = NLPSaveRestoreConnector()

        # The target file is only passed to merge_partition when the final TP size is 1;
        # otherwise the split step below receives it.
        # NOTE(review): presumably merge_partition saves only when given a file - confirm.
        if tgt_tp_size > 1:
            merge_partition(model, partitions)
        else:
            merge_partition(model, partitions, args.target_file)
    else:
        # Source already has TP=1: a single restore gives the full model.
        app_state.model_parallel_size = 1
        model = cls.restore_from(restore_path=args.model_file, trainer=trainer)

    if tgt_tp_size > 1:
        # Use the (merged or directly restored) TP=1 parameters as the single source partition.
        partitions = []
        params = [p for _, p in model.named_parameters()]
        partitions.append(params)

        # Re-create the model configured for the target TP size, then split into it.
        model.cfg.tensor_model_parallel_size = tgt_tp_size
        app_state.model_parallel_size = tgt_tp_size
        trainer = Trainer(devices=1, plugins=NLPDDPPlugin(), accelerator="cpu", precision=precision)
        model = cls(model.cfg, trainer).to('cpu')
        model._save_restore_connector = NLPSaveRestoreConnector()

        split_partition(model, partitions, tgt_tp_size, args.target_file)

    logging.info("Successfully finished changing partitions!")
def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
    """
    Transcribe audio with an ASR model and write the predictions as JSON lines.

    Audio comes either from `cfg.audio_dir` (all `*.{cfg.audio_type}` files) or
    from the `audio_filepath` field of each line in `cfg.dataset_manifest`.
    When every manifest item has both `offset` and `duration`, and the model is
    an EncDecCTCModel, `transcribe_partial_audio` is used; otherwise the model's
    own `transcribe` is called on the full files.

    Args:
        cfg: Transcription configuration (a dataclass instance is converted to
            an OmegaConf structured config). `cfg.output_filename` is filled in
            here when it is None.

    Returns:
        The (updated) `cfg`; None when the manifest file is empty. Returned
        early, without transcribing, when the output file already exists and
        `cfg.overwrite_transcripts` is false.

    Raises:
        ValueError: if neither a model source nor an audio source is configured.
    """
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    if is_dataclass(cfg):
        cfg = OmegaConf.structured(cfg)

    if cfg.model_path is None and cfg.pretrained_name is None:
        raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!")
    if cfg.audio_dir is None and cfg.dataset_manifest is None:
        raise ValueError("Both cfg.audio_dir and cfg.dataset_manifest cannot be None!")

    # setup GPU
    if cfg.cuda is None:
        if torch.cuda.is_available():
            device = [0]  # use 0th CUDA device
            accelerator = 'gpu'
        else:
            device = 1
            accelerator = 'cpu'
    else:
        device = [cfg.cuda]
        accelerator = 'gpu'

    # `device[0]` is only evaluated on the gpu path, where `device` is a list.
    map_location = torch.device('cuda:{}'.format(device[0]) if accelerator == 'gpu' else 'cpu')

    # setup model
    if cfg.model_path is not None:
        # restore model from .nemo file path, through the class recorded in its config
        model_cfg = ASRModel.restore_from(restore_path=cfg.model_path, return_config=True)
        classpath = model_cfg.target  # original class path
        imported_class = model_utils.import_class_by_path(classpath)  # type: ASRModel
        logging.info(f"Restoring model : {imported_class.__name__}")
        asr_model = imported_class.restore_from(
            restore_path=cfg.model_path, map_location=map_location
        )  # type: ASRModel
        model_name = os.path.splitext(os.path.basename(cfg.model_path))[0]
    else:
        # restore model by name
        asr_model = ASRModel.from_pretrained(
            model_name=cfg.pretrained_name, map_location=map_location
        )  # type: ASRModel
        model_name = cfg.pretrained_name

    trainer = pl.Trainer(devices=device, accelerator=accelerator)
    asr_model.set_trainer(trainer)
    asr_model = asr_model.eval()
    partial_audio = False

    # Setup decoding strategy
    if hasattr(asr_model, 'change_decoding_strategy'):
        asr_model.change_decoding_strategy(cfg.rnnt_decoding)

    # get audio filenames
    if cfg.audio_dir is not None:
        filepaths = list(glob.glob(os.path.join(cfg.audio_dir, f"*.{cfg.audio_type}")))
    else:
        # get filenames from manifest
        filepaths = []
        if os.stat(cfg.dataset_manifest).st_size == 0:
            logging.error(f"The input dataset_manifest {cfg.dataset_manifest} is empty. Exiting!")
            return None

        with open(cfg.dataset_manifest, 'r') as f:
            has_two_fields = []
            for line in f:
                item = json.loads(line)
                has_two_fields.append("offset" in item and "duration" in item)
                filepaths.append(item['audio_filepath'])
        # Partial-audio mode requires offset and duration on every item.
        partial_audio = all(has_two_fields)
    logging.info(f"\nTranscribing {len(filepaths)} files...\n")

    # setup AMP (optional)
    if cfg.amp and torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
        logging.info("AMP enabled!\n")
        autocast = torch.cuda.amp.autocast
    else:

        @contextlib.contextmanager
        def autocast():
            yield

    # Compute output filename
    if cfg.output_filename is None:
        # create default output filename
        if cfg.audio_dir is not None:
            cfg.output_filename = os.path.dirname(os.path.join(cfg.audio_dir, '.')) + '.json'
        else:
            # Insert the model name before the manifest's extension. Splitting on the
            # extension (instead of str.replace('.json', ...)) guarantees the result
            # differs from the input manifest, so the manifest is never truncated by the
            # write below, and leaves '.json' inside directory names untouched.
            manifest_root, manifest_ext = os.path.splitext(cfg.dataset_manifest)
            output_ext = manifest_ext or '.json'
            cfg.output_filename = f'{manifest_root}_{model_name}{output_ext}'

    # if transcripts should not be overwritten, and already exists, skip re-transcription step and return
    if not cfg.overwrite_transcripts and os.path.exists(cfg.output_filename):
        logging.info(
            f"Previous transcripts found at {cfg.output_filename}, and flag `overwrite_transcripts` "
            f"is {cfg.overwrite_transcripts}. Returning without re-transcribing text."
        )
        return cfg

    # transcribe audio
    with autocast():
        with torch.no_grad():
            if partial_audio and isinstance(asr_model, EncDecCTCModel):
                transcriptions = transcribe_partial_audio(
                    asr_model=asr_model,
                    path2manifest=cfg.dataset_manifest,
                    batch_size=cfg.batch_size,
                    num_workers=cfg.num_workers,
                )
            else:
                if partial_audio:
                    logging.warning(
                        "RNNT models do not support transcribe partial audio for now. Transcribing full audio."
                    )
                transcriptions = asr_model.transcribe(
                    paths2audio_files=filepaths, batch_size=cfg.batch_size, num_workers=cfg.num_workers,
                )

    logging.info(f"Finished transcribing {len(filepaths)} files !")

    logging.info(f"Writing transcriptions into file: {cfg.output_filename}")

    # if transcriptions form a tuple (from RNNT), extract just "best" hypothesis
    if isinstance(transcriptions, tuple) and len(transcriptions) == 2:
        transcriptions = transcriptions[0]

    # write audio transcriptions
    with open(cfg.output_filename, 'w', encoding='utf-8') as f:
        if cfg.audio_dir is not None:
            for idx, text in enumerate(transcriptions):
                item = {'audio_filepath': filepaths[idx], 'pred_text': text}
                f.write(json.dumps(item) + "\n")
        else:
            # Re-read the manifest so every original field is carried into the output.
            with open(cfg.dataset_manifest, 'r') as fr:
                for idx, line in enumerate(fr):
                    item = json.loads(line)
                    item['pred_text'] = transcriptions[idx]
                    f.write(json.dumps(item) + "\n")

    logging.info("Finished writing predictions !")
    return cfg
def from_config_dict(cls, config: 'DictConfig', trainer: Optional['Trainer'] = None):
    """
    Build an instance from a DictConfig-style configuration.

    When Hydra is available and the config carries Hydra target keys
    (`cls`/`target` together with `params`, or `_target_`), instantiation is
    delegated to `hydra.utils.instantiate`. Otherwise the class named by
    `config["target"]` is tried first and `cls` is the fallback. `trainer` is
    forwarded only to constructors for which
    `Serialization._inspect_signature_for_trainer` returns a truthy value.

    If the fallback to `cls` also fails, the error recorded for the target
    class (if any) is logged and the fallback's exception is re-raised.
    """
    # Resolve the config dict
    if _HAS_HYDRA:
        if isinstance(config, DictConfig):
            resolved = OmegaConf.to_container(config, resolve=True)
            config = OmegaConf.create(resolved)
            OmegaConf.set_struct(config, True)

        config = maybe_update_config_version(config)

    # Hydra 0.x API uses (`cls` | `target`) + `params`; Hydra 1.x API uses `_target_`.
    hydra_0x_keys = ('cls' in config or 'target' in config) and 'params' in config
    if _HAS_HYDRA and (hydra_0x_keys or '_target_' in config):
        # regular hydra-based instantiation
        instance = hydra.utils.instantiate(config=config)
    else:
        instance = None
        prev_error = ""

        # Attempt class path resolution from config `target` class (if it exists)
        if 'target' in config:
            target_cls = config["target"]  # No guarantee that this is a omegaconf class
            try:
                imported_cls = import_class_by_path(target_cls)
                # If the calling class is a subclass of the imported class, prefer the subclass.
                if issubclass(cls, imported_cls):
                    imported_cls = cls

                init_kwargs = {'cfg': config}
                if Serialization._inspect_signature_for_trainer(imported_cls):
                    init_kwargs['trainer'] = trainer
                instance = imported_cls(**init_kwargs)
            except Exception as e:
                # record previous error
                tb = traceback.format_exc()
                prev_error = f"Model instantiation failed!\nTarget class:\t{target_cls}" f"\nError(s):\t{e}\n{tb}"
                logging.debug(prev_error + "\nFalling back to `cls`.")

        # target class resolution was unsuccessful, fall back to current `cls`
        if instance is None:
            try:
                init_kwargs = {'cfg': config}
                if Serialization._inspect_signature_for_trainer(cls):
                    init_kwargs['trainer'] = trainer
                instance = cls(**init_kwargs)
            except Exception as e:
                # report saved errors, if any, and raise
                if prev_error:
                    logging.error(prev_error)
                raise e

    if not hasattr(instance, '_cfg'):
        instance._cfg = config
    return instance
def main(cfg: TranscriptionConfig):
    """
    Transcribe audio with an ASR model and write the predictions as JSON lines.

    Audio comes either from `cfg.audio_dir` (all `*.{cfg.audio_type}` files) or
    from the `audio_filepath` field of each line in `cfg.dataset_manifest`.
    With a manifest, each output line is the input item plus `pred_text`.
    `cfg.cuda` and `cfg.output_filename` are filled in when they are None.

    Raises:
        ValueError: if neither a model source nor an audio source is configured.
    """
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    if cfg.model_path is None and cfg.pretrained_name is None:
        raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!")
    if cfg.audio_dir is None and cfg.dataset_manifest is None:
        raise ValueError("Both cfg.audio_dir and cfg.dataset_manifest cannot be None!")

    # setup GPU
    if cfg.cuda is None:
        cfg.cuda = torch.cuda.is_available()

    # Exact type comparison on purpose: a bool (from is_available()) is not an int
    # here, so it maps to device 0 rather than device 1.
    if type(cfg.cuda) == int:
        device_id = int(cfg.cuda)
    else:
        device_id = 0

    device = torch.device(f'cuda:{device_id}' if cfg.cuda else 'cpu')

    # setup model
    if cfg.model_path is not None:
        # restore model from .nemo file path, through the class recorded in its config
        model_cfg = ASRModel.restore_from(restore_path=cfg.model_path, return_config=True)
        classpath = model_cfg.target  # original class path
        imported_class = model_utils.import_class_by_path(classpath)  # type: ASRModel
        logging.info(f"Restoring model : {imported_class.__name__}")
        asr_model = imported_class.restore_from(restore_path=cfg.model_path, map_location=device)  # type: ASRModel
        model_name = os.path.splitext(os.path.basename(cfg.model_path))[0]
    else:
        # restore model by name
        asr_model = ASRModel.from_pretrained(model_name=cfg.pretrained_name, map_location=device)  # type: ASRModel
        model_name = cfg.pretrained_name

    trainer = pl.Trainer(gpus=int(cfg.cuda))
    asr_model.set_trainer(trainer)
    asr_model = asr_model.eval()

    # Setup decoding strategy
    if hasattr(asr_model, 'change_decoding_strategy'):
        asr_model.change_decoding_strategy(cfg.rnnt_decoding)

    # get audio filenames
    if cfg.audio_dir is not None:
        filepaths = list(glob.glob(os.path.join(cfg.audio_dir, f"*.{cfg.audio_type}")))
    else:
        # get filenames from manifest
        filepaths = []
        with open(cfg.dataset_manifest, 'r') as f:
            for line in f:
                item = json.loads(line)
                filepaths.append(item['audio_filepath'])
    logging.info(f"\nTranscribing {len(filepaths)} files...\n")

    # setup AMP (optional)
    if cfg.amp and torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
        logging.info("AMP enabled!\n")
        autocast = torch.cuda.amp.autocast
    else:

        @contextlib.contextmanager
        def autocast():
            yield

    # transcribe audio
    with autocast():
        with torch.no_grad():
            transcriptions = asr_model.transcribe(filepaths, batch_size=cfg.batch_size)
    logging.info(f"Finished transcribing {len(filepaths)} files !")

    if cfg.output_filename is None:
        # create default output filename
        if cfg.audio_dir is not None:
            cfg.output_filename = os.path.dirname(os.path.join(cfg.audio_dir, '.')) + '.json'
        else:
            # Insert the model name before the manifest's extension. Splitting on the
            # extension (instead of str.replace('.json', ...)) guarantees the result
            # differs from the input manifest, so the manifest is never truncated by the
            # write below, and leaves '.json' inside directory names untouched.
            manifest_root, manifest_ext = os.path.splitext(cfg.dataset_manifest)
            output_ext = manifest_ext or '.json'
            cfg.output_filename = f'{manifest_root}_{model_name}{output_ext}'

    logging.info(f"Writing transcriptions into file: {cfg.output_filename}")

    with open(cfg.output_filename, 'w', encoding='utf-8') as f:
        if cfg.audio_dir is not None:
            for idx, text in enumerate(transcriptions):
                item = {'audio_filepath': filepaths[idx], 'pred_text': text}
                f.write(json.dumps(item) + "\n")
        else:
            # Re-read the manifest so every original field is carried into the output.
            with open(cfg.dataset_manifest, 'r') as fr:
                for idx, line in enumerate(fr):
                    item = json.loads(line)
                    item['pred_text'] = transcriptions[idx]
                    f.write(json.dumps(item) + "\n")

    logging.info("Finished writing predictions !")
def main():
    """
    Command-line entry point: average the `.ckpt` checkpoints stored next to each
    given .nemo file and save the result as `<name>-averaged.nemo`.

    For every positional argument (a .nemo file, or a folder containing exactly
    one non-averaged .nemo file):
        1. restore the model through the class recorded in its config,
        2. sum the state dicts of all `*.ckpt` files in the same folder
           (files ending in `-last.ckpt` are ignored),
        3. divide non-integer tensors by the number of checkpoints
           (integer tensors stay accumulated),
        4. load the averaged state into the model and save it.

    Raises:
        RuntimeError: if a folder does not contain exactly one candidate .nemo
            file, or if no checkpoints are found to average.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'model_fname_list',
        metavar='N',
        type=str,
        nargs='+',
        help='Input .nemo files (or folders who contains them) to parse',
    )
    parser.add_argument(
        '--import_fname_list',
        type=str,
        nargs='+',
        default=[],
        help='A list of Python file names to "from FILE import *" (Needed when some classes were defined in __main__ of a script)',
    )
    args = parser.parse_args()

    logging.info(
        f"\n\nIMPORTANT: Use --import_fname_list for all files that contain missing classes (AttributeError: Can't get attribute '???' on <module '__main__' from '???'>)\n\n"
    )

    # Emulate "from FILE import *" so classes originally defined in a script's
    # __main__ can be resolved from this module's globals.
    for fn in args.import_fname_list:
        logging.info(f"Importing * from {fn}")
        sys.path.insert(0, os.path.dirname(fn))
        globals().update(importlib.import_module(os.path.splitext(os.path.basename(fn))[0]).__dict__)

    device = torch.device("cpu")

    # loop over all folders with .nemo files (or .nemo files)
    for model_fname_i, model_fname in enumerate(args.model_fname_list):
        if not model_fname.endswith(".nemo"):
            # assume model_fname is a folder which contains a single .nemo file
            # (files matching "*-averaged.nemo" are outputs of this script and are ignored)
            nemo_files = [
                path
                for path in glob.glob(os.path.join(model_fname, "*.nemo"))
                if not path.endswith("-averaged.nemo")
            ]
            if len(nemo_files) != 1:
                raise RuntimeError(
                    f"Expected only a single .nemo files but discovered {len(nemo_files)} .nemo files"
                )

            model_fname = nemo_files[0]

        # dirname() is '' for a bare file name and os.listdir('') raises, so use the
        # current directory in that case.
        model_folder_path = os.path.dirname(model_fname) or "."
        fn, fe = os.path.splitext(model_fname)
        avg_model_fname = f"{fn}-averaged{fe}"

        logging.info(
            f"\n===> [{model_fname_i+1} / {len(args.model_fname_list)}] Parsing folder {model_folder_path}\n"
        )

        # restore model from .nemo file path
        model_cfg = ModelPT.restore_from(restore_path=model_fname, return_config=True)
        classpath = model_cfg.target  # original class path
        imported_class = model_utils.import_class_by_path(classpath)
        logging.info(f"Loading model {model_fname}")
        nemo_model = imported_class.restore_from(restore_path=model_fname, map_location=device)

        # search for all checkpoints (ignore -last.ckpt)
        checkpoint_paths = [
            os.path.join(model_folder_path, x)
            for x in os.listdir(model_folder_path)
            if x.endswith('.ckpt') and not x.endswith('-last.ckpt')
        ]

        # < Checkpoint Averaging Logic >
        n = len(checkpoint_paths)
        if n == 0:
            # Fail with a clear message instead of a TypeError on the None state below.
            raise RuntimeError(f"No .ckpt checkpoints found to average in {model_folder_path}")

        # load state dicts
        avg_state = None

        logging.info(f"Averaging {n} checkpoints ...")

        for ix, path in enumerate(checkpoint_paths):
            checkpoint = torch.load(path, map_location=device)

            if 'state_dict' in checkpoint:
                checkpoint = checkpoint['state_dict']

            if ix == 0:
                # Initial state
                avg_state = checkpoint
                logging.info(f"Initialized average state dict with checkpoint : {path}")
            else:
                # Accumulated state
                for k in avg_state:
                    avg_state[k] = avg_state[k] + checkpoint[k]
                logging.info(f"Updated average state dict with state from checkpoint : {path}")

        for k in avg_state:
            if str(avg_state[k].dtype).startswith("torch.int"):
                # For int type, not averaged, but only accumulated.
                # e.g. BatchNorm.num_batches_tracked
                pass
            else:
                avg_state[k] = avg_state[k] / n

        # restore merged weights into model
        nemo_model.load_state_dict(avg_state, strict=True)
        # Save model
        logging.info(f"Saving average mdel to: {avg_model_fname}")
        nemo_model.save_to(avg_model_fname)
if __name__ == "__main__": args = parser.parse_args() os.makedirs(args.output_dir, exist_ok=True) text_files = [] if args.in_text: if args.model is None: raise ValueError( f"ASR model must be provided to extract vocabulary for text processing" ) elif os.path.exists(args.model): model_cfg = ASRModel.restore_from(restore_path=args.model, return_config=True) classpath = model_cfg.target # original class path imported_class = model_utils.import_class_by_path( classpath) # type: ASRModel print(f"Restoring model : {imported_class.__name__}") asr_model = imported_class.restore_from( restore_path=args.model) # type: ASRModel model_name = os.path.splitext(os.path.basename(args.model))[0] else: # restore model by name asr_model = ASRModel.from_pretrained( model_name=args.model) # type: ASRModel model_name = args.model vocabulary = asr_model.cfg.decoder.vocabulary if os.path.isdir(args.in_text): text_files = Path(args.in_text).glob(("*.txt")) else: