def load_pretrained_params(
    model: nn.Module,
    url: Optional[str] = None,
    hash_prefix: Optional[str] = None,
    overwrite: bool = False,
    **kwargs: Any,
) -> None:
    """Load a set of parameters onto a PyTorch model

    Example::
        >>> from doctr.models import load_pretrained_params
        >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip")

    Args:
        model: the PyTorch model to be loaded (its ``load_state_dict`` is called)
        url: URL of the checkpoint file, read with ``torch.load``; when None, a warning is
            logged and the model keeps its current (default) initialization
        hash_prefix: first characters of SHA256 expected hash
        overwrite: not read by this implementation (no archive is extracted here); kept so
            the signature stays the same for existing callers
        **kwargs: forwarded to ``download_from_url``
    """
    if url is None:
        logging.warning("Invalid model URL, using default initialization.")
    else:
        archive_path = download_from_url(url, hash_prefix=hash_prefix, cache_subdir='models', **kwargs)

        # Read state_dict; map_location='cpu' remaps stored tensors onto the CPU
        state_dict = torch.load(archive_path, map_location='cpu')
        # Load weights (strict: the state_dict keys must match the model's)
        model.load_state_dict(state_dict)
def load_pretrained_params(
    model: nn.Module,
    url: Optional[str] = None,
    hash_prefix: Optional[str] = None,
    overwrite: bool = False,
    ignore_keys: Optional[List[str]] = None,
    **kwargs: Any,
) -> None:
    """Download a checkpoint and load its state_dict into a PyTorch model.

    >>> from doctr.models import load_pretrained_params
    >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip")

    Args:
        model: the PyTorch model whose parameters get replaced
        url: location of the checkpoint; when None, a warning is logged and the model is
            left with its current (default) initialization
        hash_prefix: first characters of the expected SHA256 hash of the downloaded file
        overwrite: accepted for signature compatibility; not read by this function
        ignore_keys: parameter names to drop from the checkpoint before loading. Each one is
            popped from the checkpoint (``pop`` without a default, so an absent key raises
            ``KeyError`` assuming the checkpoint is a dict), and afterwards the model must
            report exactly these names as missing.
        **kwargs: forwarded to ``download_from_url``

    Raises:
        ValueError: when ``ignore_keys`` is non-empty and either the keys the model reports
            as missing differ from ``ignore_keys``, or the checkpoint contains keys the
            model does not expect
    """
    if url is None:
        logging.warning("Invalid model URL, using default initialization.")
        return

    checkpoint_path = download_from_url(url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs)
    # Tensors are remapped onto the CPU regardless of the device they were saved from
    weights = torch.load(checkpoint_path, map_location="cpu")

    if ignore_keys is None or len(ignore_keys) == 0:
        # Nothing to skip: plain strict load
        model.load_state_dict(weights)
        return

    for dropped in ignore_keys:
        weights.pop(dropped)

    # Non-strict load, then check that the only holes are exactly the dropped entries
    missing, unexpected = model.load_state_dict(weights, strict=False)
    if set(missing) != set(ignore_keys) or len(unexpected) > 0:
        raise ValueError("unable to load state_dict, due to non-matching keys.")
def __init__(
    self,
    url: str,
    file_name: Optional[str] = None,
    file_hash: Optional[str] = None,
    extract_archive: bool = False,
    download: bool = False,
    overwrite: bool = False,
    cache_dir: Optional[str] = None,
    cache_subdir: Optional[str] = None,
    **kwargs: Any,
) -> None:
    """Resolve a remote dataset file locally, optionally unpack it, then build the dataset.

    Args:
        url: remote location of the dataset file
        file_name: local name for the file; when not a string, the basename of ``url`` is used
        file_hash: expected hash, forwarded to ``download_from_url``
        extract_archive: whether to unpack the file into a sibling folder named after its stem
        download: must be True when the file is not already at its cached location
        overwrite: force unpacking even if the target folder already exists
        cache_dir: cache root, defaulting to ``~/.cache/doctr``
        cache_subdir: sub-folder of the cache root, defaulting to ``datasets``
        **kwargs: forwarded to the parent class constructor

    Raises:
        ValueError: if the file is not present in the cache and ``download`` is False
    """
    if cache_dir is None:
        cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "doctr")
    if cache_subdir is None:
        cache_subdir = "datasets"
    if not isinstance(file_name, str):
        file_name = os.path.basename(url)

    # Refuse to go further when the file is not cached and downloading was not allowed
    expected_path = os.path.join(cache_dir, cache_subdir, file_name)
    if not os.path.exists(expected_path) and not download:
        raise ValueError("the dataset needs to be downloaded first with download=True")

    # Always routed through download_from_url; presumably it reuses an already cached
    # file instead of re-downloading (not visible here — confirm)
    local_file = download_from_url(url, file_name, file_hash, cache_dir=cache_dir, cache_subdir=cache_subdir)

    root = local_file
    if extract_archive:
        archive = Path(local_file)
        root = archive.parent.joinpath(archive.stem)
        if overwrite or not root.is_dir():
            shutil.unpack_archive(archive, root)

    super().__init__(root, **kwargs)
def load_pretrained_params(
    model: Model,
    url: Optional[str] = None,
    hash_prefix: Optional[str] = None,
    overwrite: bool = False,
    internal_name: str = 'weights',
    **kwargs: Any,
) -> None:
    """Fetch a zipped checkpoint, unpack it and restore the weights into a Keras model.

    Example::
        >>> from doctr.models import load_pretrained_params
        >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip")

    Args:
        model: the keras model receiving the weights (via ``load_weights``)
        url: URL of the zipped set of parameters; when None, a warning is logged and the
            model keeps its current (default) initialization
        hash_prefix: first characters of SHA256 expected hash
        overwrite: re-extract the zip even when the extraction folder already exists
        internal_name: name of the ckpt files inside the extracted folder
        **kwargs: forwarded to ``download_from_url``
    """
    if url is None:
        logging.warning("Invalid model URL, using default initialization.")
        return

    # download_from_url is assumed to return a pathlib.Path (``.parent``/``.stem`` are used)
    zipped = download_from_url(url, hash_prefix=hash_prefix, cache_subdir='models', **kwargs)
    # Unpack next to the archive, into a folder named after it without its last suffix
    extract_dir = zipped.parent.joinpath(zipped.stem)

    must_extract = not extract_dir.is_dir() or overwrite
    if must_extract:
        with ZipFile(zipped, 'r') as bundle:
            bundle.extractall(path=extract_dir)

    weights_prefix = f"{extract_dir}{os.sep}{internal_name}"
    model.load_weights(weights_prefix)