Code Example #1
import logging
from typing import Any, Optional

import torch
from torch import nn

from doctr.utils.data import download_from_url  # import path assumed


def load_pretrained_params(
    model: nn.Module,
    url: Optional[str] = None,
    hash_prefix: Optional[str] = None,
    overwrite: bool = False,
    **kwargs: Any,
) -> None:
    """Load a set of parameters onto a model

    Example::
        >>> from doctr.models import load_pretrained_params
        >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip")

    Args:
        model: the PyTorch model to be loaded
        url: URL of the zipped set of parameters
        hash_prefix: first characters of SHA256 expected hash
        overwrite: should the zip extraction be enforced if the archive has already been extracted
    """

    if url is None:
        logging.warning("Invalid model URL, using default initialization.")
    else:
        archive_path = download_from_url(url,
                                         hash_prefix=hash_prefix,
                                         cache_subdir='models',
                                         **kwargs)

        # Read state_dict
        state_dict = torch.load(archive_path, map_location='cpu')

        # Load weights
        model.load_state_dict(state_dict)
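A minimal usage sketch for the PyTorch variant above; the model is an illustrative torch.nn module, not one of doctr's. Passing url=None only logs a warning and keeps the default initialization.

# Usage sketch (illustrative model)
from torch import nn

model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4))

# url=None: warns and keeps the default initialization
load_pretrained_params(model, url=None)

# With a real checkpoint URL, the archive is cached and loaded onto the model:
# load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip")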
Code Example #2
import logging
from typing import Any, List, Optional

import torch
from torch import nn

from doctr.utils.data import download_from_url  # import path assumed


def load_pretrained_params(
    model: nn.Module,
    url: Optional[str] = None,
    hash_prefix: Optional[str] = None,
    overwrite: bool = False,
    ignore_keys: Optional[List[str]] = None,
    **kwargs: Any,
) -> None:
    """Load a set of parameters onto a model

    >>> from doctr.models import load_pretrained_params
    >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip")

    Args:
        model: the PyTorch model to be loaded
        url: URL of the zipped set of parameters
        hash_prefix: first characters of SHA256 expected hash
        overwrite: should the zip extraction be enforced if the archive has already been extracted
        ignore_keys: list of weights to be ignored from the state_dict
    """

    if url is None:
        logging.warning("Invalid model URL, using default initialization.")
    else:
        archive_path = download_from_url(url,
                                         hash_prefix=hash_prefix,
                                         cache_subdir="models",
                                         **kwargs)

        # Read state_dict
        state_dict = torch.load(archive_path, map_location="cpu")

        # Remove weights from the state_dict
        if ignore_keys is not None and len(ignore_keys) > 0:
            for key in ignore_keys:
                state_dict.pop(key)
            missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
            if set(missing_keys) != set(ignore_keys) or len(unexpected_keys) > 0:
                raise ValueError("unable to load state_dict, due to non-matching keys.")
        else:
            # Load weights
            model.load_state_dict(state_dict)
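The ignore_keys variant matters when the checkpoint head does not match the current model, e.g. when fine-tuning with a different number of output classes. A hedged sketch; the key names and URL below are placeholders, not entries from a real doctr checkpoint.

# Hypothetical key names and URL; inspect real names via model.state_dict().keys()
load_pretrained_params(
    model,
    url="https://yoursource.com/yourcheckpoint-yourhash.zip",
    ignore_keys=["classifier.weight", "classifier.bias"],
)
# The popped keys stay at their fresh initialization; load_state_dict(strict=False)
# then checks that nothing else is missing or unexpected, and raises otherwise.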
Code Example #3
File: base.py Project: mindee/doctr
    def __init__(
        self,
        url: str,
        file_name: Optional[str] = None,
        file_hash: Optional[str] = None,
        extract_archive: bool = False,
        download: bool = False,
        overwrite: bool = False,
        cache_dir: Optional[str] = None,
        cache_subdir: Optional[str] = None,
        **kwargs: Any,
    ) -> None:

        if cache_dir is None:
            cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "doctr")
        cache_subdir = "datasets" if cache_subdir is None else cache_subdir

        file_name = file_name if isinstance(file_name, str) else os.path.basename(url)

        # Download the file if not present
        archive_path: Union[str, Path] = os.path.join(cache_dir, cache_subdir, file_name)

        if not os.path.exists(archive_path) and not download:
            raise ValueError(
                "the dataset needs to be downloaded first with download=True")

        # Fetch the archive (download_from_url resolves to the cached file when present)
        archive_path = download_from_url(
            url, file_name, file_hash, cache_dir=cache_dir, cache_subdir=cache_subdir
        )

        # Extract the archive
        if extract_archive:
            archive_path = Path(archive_path)
            dataset_path = archive_path.parent.joinpath(archive_path.stem)
            if not dataset_path.is_dir() or overwrite:
                shutil.unpack_archive(archive_path, dataset_path)

        super().__init__(dataset_path if extract_archive else archive_path, **kwargs)
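With the defaults above, every artifact resolves under ~/.cache/doctr/<cache_subdir>/<file_name>. A small sketch mirroring that path logic (the URL is a placeholder):

import os

url = "https://yoursource.com/yourdataset.zip"  # placeholder
cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "doctr")
archive_path = os.path.join(cache_dir, "datasets", os.path.basename(url))
print(archive_path)  # e.g. /home/<user>/.cache/doctr/datasets/yourdataset.zip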
Code Example #4
import logging
import os
from typing import Any, Optional
from zipfile import ZipFile

from tensorflow.keras import Model

from doctr.utils.data import download_from_url  # import path assumed


def load_pretrained_params(
    model: Model,
    url: Optional[str] = None,
    hash_prefix: Optional[str] = None,
    overwrite: bool = False,
    internal_name: str = 'weights',
    **kwargs: Any,
) -> None:
    """Load a set of parameters onto a model

    Example::
        >>> from doctr.models import load_pretrained_params
        >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip")

    Args:
        model: the Keras model to be loaded
        url: URL of the zipped set of parameters
        hash_prefix: first characters of SHA256 expected hash
        overwrite: should the zip extraction be enforced if the archive has already been extracted
        internal_name: name of the ckpt files
    """

    if url is None:
        logging.warning("Invalid model URL, using default initialization.")
    else:
        archive_path = download_from_url(url,
                                         hash_prefix=hash_prefix,
                                         cache_subdir='models',
                                         **kwargs)

        # Unzip the archive
        params_path = archive_path.parent.joinpath(archive_path.stem)
        if not params_path.is_dir() or overwrite:
            with ZipFile(archive_path, 'r') as f:
                f.extractall(path=params_path)

        # Load weights
        model.load_weights(f"{params_path}{os.sep}{internal_name}")
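A matching usage sketch for this TensorFlow/Keras variant; the model below is illustrative, and url=None again falls back to the default initialization.

# Usage sketch (illustrative model)
from tensorflow.keras import Sequential, layers

model = Sequential([layers.Dense(8, activation="relu"), layers.Dense(4)])
model.build(input_shape=(None, 16))

load_pretrained_params(model, url=None)  # warns and keeps default weights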