def from_pretrained(
    cls,
    model_name: str,
    refresh_cache: bool = False,
    override_config_path: Optional[str] = None,
    map_location: Optional['torch.device'] = None,
    strict: bool = True,
    return_config: bool = False,
):
    """
    Instantiates an instance of NeMo from NVIDIA NGC cloud.
    Use restore_from() to instantiate from a local .nemo file.

    Args:
        model_name: string key which will be used to find the module. It is compared against the
            `pretrained_model_name` of every entry returned by `cls.list_available_models()` and,
            when an entry exposes an `aliases` attribute, against each alias as well.
        refresh_cache: If set to True, then when fetching from cloud, this will re-fetch the file
            from cloud even if it is already found in a cache locally.
        override_config_path: path to a yaml config that will override the internal config file
        map_location: Optional torch.device() to map the instantiated model to a device.
            By default (None), it will select a GPU if available, falling back to CPU otherwise.
        strict: Passed to torch.load_state_dict
        return_config: If set to true, will return just the underlying config of the restored
            model as an OmegaConf DictConfig object without instantiating the model.

    Returns:
        A model instance of a particular model class or its underlying config (if return_config is set).

    Raises:
        FileNotFoundError: If no entry of `cls.list_available_models()` matches `model_name`
            (including the case where the list itself is None).
    """
    location_in_the_cloud = None
    description = None
    class_ = None

    # Query the registry once instead of once for the None check and again for the iteration.
    available_models = cls.list_available_models()
    if available_models is not None:
        for pretrained_model_info in available_models:
            # `aliases` is read defensively: entries without that attribute simply have no aliases.
            aliases = getattr(pretrained_model_info, 'aliases', None)
            matches_name = pretrained_model_info.pretrained_model_name == model_name
            matches_alias = aliases is not None and any(alias == model_name for alias in aliases)
            if matches_name or matches_alias:
                location_in_the_cloud = pretrained_model_info.location
                description = pretrained_model_info.description
                class_ = pretrained_model_info.class_
                # First match wins, mirroring _get_ngc_pretrained_model_info().
                break

    if location_in_the_cloud is None:
        raise FileNotFoundError(
            f"Model {model_name} was not found. Check cls.list_available_models() for the list of all available models."
        )

    # Split on the LAST '/' only. Using str.replace(filename, "") would also strip the file name
    # from any earlier part of the URL in which it happens to occur.
    url_head, url_sep, filename = location_in_the_cloud.rpartition("/")
    url = url_head + url_sep

    # filename[:-5] drops the last five characters (the ".nemo" suffix of a standard checkpoint name).
    cache_dir = Path.joinpath(Path.home(), f'.cache/torch/NeMo/NeMo_{nemo.__version__}/{filename[:-5]}')

    # If either description and location in the cloud changes, this will force re-download.
    # A missing (None) description is hashed as an empty string instead of raising TypeError.
    cache_subfolder = hashlib.md5((location_in_the_cloud + (description or "")).encode('utf-8')).hexdigest()

    # if file exists on cache_folder/subfolder, it will be re-used, unless refresh_cache is True
    nemo_model_file_in_cache = maybe_download_from_cloud(
        url=url, filename=filename, cache_dir=cache_dir, subfolder=cache_subfolder, refresh_cache=refresh_cache
    )

    logging.info("Instantiating model from pre-trained checkpoint")

    # Entries that do not declare a concrete class are restored with the calling class.
    if class_ is None:
        class_ = cls

    instance = class_.restore_from(
        restore_path=nemo_model_file_in_cache,
        override_config_path=override_config_path,
        map_location=map_location,
        strict=strict,
        return_config=return_config,
    )
    return instance
def _get_ngc_pretrained_model_info(cls, model_name: str, refresh_cache: bool = False) -> (type, str):
    """
    Resolve the NGC model pretrained information given a model name.
    Assumes the model subclass implements the `list_available_models()` inherited method.

    Args:
        model_name: Str name of the model. Must be the original name or an alias of the model, without any '/'.
        refresh_cache: Bool, determines whether cache must be refreshed (model is re-downloaded).

    Returns:
        A tuple of details describing :
        -   The resolved class of the model. This requires subclass to implement PretrainedModelInfo.class_.
            If the class cannot be resolved, default to the class that called this method.
        -   The path to the NeMo model (.nemo file) in some cached directory.

    Raises:
        FileNotFoundError: If no entry of `cls.list_available_models()` matches `model_name` by name
            or alias (including the case where the list itself is None).
    """
    location_in_the_cloud = None
    description = None
    class_ = None

    models = cls.list_available_models()
    if models is not None:
        # Iterate the list that was already fetched rather than querying the registry a second time.
        for pretrained_model_info in models:
            found = pretrained_model_info.pretrained_model_name == model_name
            if not found and pretrained_model_info.aliases is not None:
                found = any(alias == model_name for alias in pretrained_model_info.aliases)

            if found:
                location_in_the_cloud = pretrained_model_info.location
                description = pretrained_model_info.description
                class_ = pretrained_model_info.class_
                # First matching entry wins.
                break

    if location_in_the_cloud is None:
        raise FileNotFoundError(
            f"Model {model_name} was not found. Check cls.list_available_models() for the list of all available models."
        )

    # Split on the LAST '/' only. Using str.replace(filename, "") would also strip the file name
    # from any earlier part of the URL in which it happens to occur.
    url_head, url_sep, filename = location_in_the_cloud.rpartition("/")
    url = url_head + url_sep

    # filename[:-5] drops the last five characters (the ".nemo" suffix of a standard checkpoint name).
    cache_dir = Path.joinpath(model_utils.resolve_cache_dir(), f'{filename[:-5]}')

    # If either description and location in the cloud changes, this will force re-download.
    # A missing (None) description is hashed as an empty string instead of raising TypeError.
    cache_subfolder = hashlib.md5((location_in_the_cloud + (description or "")).encode('utf-8')).hexdigest()

    # if file exists on cache_folder/subfolder, it will be re-used, unless refresh_cache is True
    nemo_model_file_in_cache = maybe_download_from_cloud(
        url=url, filename=filename, cache_dir=cache_dir, subfolder=cache_subfolder, refresh_cache=refresh_cache
    )

    logging.info("Instantiating model from pre-trained checkpoint")

    # Entries that do not declare a concrete class resolve to the calling class.
    if class_ is None:
        class_ = cls

    return class_, nemo_model_file_in_cache