Ejemplo n.º 1
0
    def from_pretrained(
        cls,
        model_name: str,
        refresh_cache: bool = False,
        override_config_path: Optional[str] = None,
        map_location: Optional['torch.device'] = None,
        strict: bool = True,
        return_config: bool = False,
    ):
        """
        Instantiate a pretrained model from the NVIDIA NGC cloud.

        The model is looked up by name in ``cls.list_available_models()``, fetched into a
        local cache via ``maybe_download_from_cloud`` and then handed to ``restore_from()``.
        Use ``restore_from()`` directly to instantiate from a local .nemo file.

        Args:
            model_name: string key matched against ``pretrained_model_name`` of each
                available model entry.
            refresh_cache: If set to True, then when fetching from cloud, this will re-fetch the file
                from cloud even if it is already found in a cache locally.
            override_config_path: path to a yaml config that will override the internal
                config file. Forwarded to ``restore_from()``.
            map_location: Optional torch.device() to map the instantiated model to a device.
                Forwarded to ``restore_from()``.
            strict: Forwarded to ``restore_from()`` (documented upstream as being passed to
                torch.load_state_dict).
            return_config: Forwarded to ``restore_from()``. If set to true, the underlying config
                of the restored model is returned instead of a model instance.

        Returns:
            Whatever ``restore_from()`` returns: a model instance of a particular model class,
            or its underlying config (if return_config is set).

        Raises:
            FileNotFoundError: if no available model entry matches ``model_name`` (or the
                matching entry has no location).
        """
        location_in_the_cloud = None
        description = None
        # Bound up front so the fallback check below never depends on the loop having run.
        class_ = None

        # Query the registry once and reuse the result for both the None check and the scan.
        available_models = cls.list_available_models()
        if available_models is not None:
            # No early exit: if several entries share the same name, the last one wins.
            for pretrained_model_info in available_models:
                if pretrained_model_info.pretrained_model_name == model_name:
                    location_in_the_cloud = pretrained_model_info.location
                    description = pretrained_model_info.description
                    class_ = pretrained_model_info.class_

        if location_in_the_cloud is None:
            raise FileNotFoundError(
                f"Model {model_name} was not found. Check cls.list_available_models() for the list of all available models."
            )

        filename = location_in_the_cloud.split("/")[-1]
        # Strip only the trailing file name. str.replace() would remove every occurrence of
        # the file name inside the URL and corrupt locations whose path repeats it.
        url = location_in_the_cloud[: len(location_in_the_cloud) - len(filename)]
        # filename[:-5] drops the last five characters (presumably the ".nemo" extension).
        cache_dir = Path.home().joinpath(f'.cache/torch/NeMo/NeMo_{nemo.__version__}/{filename[:-5]}')
        # If either description or location in the cloud changes, this will force re-download
        cache_subfolder = hashlib.md5((location_in_the_cloud + description).encode('utf-8')).hexdigest()
        # If the file exists in cache_dir/subfolder, it will be re-used, unless refresh_cache is True
        nemo_model_file_in_cache = maybe_download_from_cloud(
            url=url, filename=filename, cache_dir=cache_dir, subfolder=cache_subfolder, refresh_cache=refresh_cache
        )
        logging.info("Instantiating model from pre-trained checkpoint")
        # Entries that do not name a concrete class are restored with the calling class.
        if class_ is None:
            class_ = cls
        instance = class_.restore_from(
            restore_path=nemo_model_file_in_cache,
            override_config_path=override_config_path,
            map_location=map_location,
            strict=strict,
            return_config=return_config,
        )
        return instance
Ejemplo n.º 2
0
    def _get_ngc_pretrained_model_info(cls,
                                       model_name: str,
                                       refresh_cache: bool = False
                                       ) -> (type, str):
        """
        Resolve the NGC model pretrained information given a model name.
        Assumes the model subclass implements the `list_available_models()` inherited method.

        The first entry whose `pretrained_model_name` or one of whose `aliases` equals
        `model_name` is used; its file is fetched through `maybe_download_from_cloud`.

        Args:
            model_name: Str name of the model. Must be the original name or an alias of the model, without any '/'.
            refresh_cache: Bool, determines whether cache must be refreshed (model is re-downloaded).

        Returns:
            A tuple of details describing :
            -   The resolved class of the model. This requires subclass to implement PretrainedModelInfo.class_.
                If the class cannot be resolved, default to the class that called this method.
            -   The value returned by `maybe_download_from_cloud`, i.e. the path to the NeMo model
                (.nemo file) in some cached directory.

        Raises:
            FileNotFoundError: if no available model entry matches `model_name` (or the
                matching entry has no location).
        """
        location_in_the_cloud = None
        description = None
        class_ = None

        # Query the registry once and iterate over that same result.
        models = cls.list_available_models()
        if models is not None:
            for pretrained_model_info in models:
                found = pretrained_model_info.pretrained_model_name == model_name
                if not found and pretrained_model_info.aliases is not None:
                    found = any(alias == model_name for alias in pretrained_model_info.aliases)
                if found:
                    location_in_the_cloud = pretrained_model_info.location
                    description = pretrained_model_info.description
                    class_ = pretrained_model_info.class_
                    # First match wins.
                    break

        if location_in_the_cloud is None:
            raise FileNotFoundError(
                f"Model {model_name} was not found. Check cls.list_available_models() for the list of all available models."
            )

        filename = location_in_the_cloud.split("/")[-1]
        # Strip only the trailing file name. str.replace() would remove every occurrence of
        # the file name inside the URL and corrupt locations whose path repeats it.
        url = location_in_the_cloud[: len(location_in_the_cloud) - len(filename)]
        # filename[:-5] drops the last five characters (presumably the ".nemo" extension).
        cache_dir = Path.joinpath(model_utils.resolve_cache_dir(),
                                  f'{filename[:-5]}')
        # If either description or location in the cloud changes, this will force re-download
        cache_subfolder = hashlib.md5(
            (location_in_the_cloud + description).encode('utf-8')).hexdigest()
        # If the file exists in cache_dir/subfolder, it will be re-used, unless refresh_cache is True
        nemo_model_file_in_cache = maybe_download_from_cloud(
            url=url,
            filename=filename,
            cache_dir=cache_dir,
            subfolder=cache_subfolder,
            refresh_cache=refresh_cache)

        logging.info("Instantiating model from pre-trained checkpoint")

        # Entries that do not name a concrete class resolve to the calling class.
        if class_ is None:
            class_ = cls

        return class_, nemo_model_file_in_cache