def from_hub(repo_id: str, **kwargs: Any):
    """Instantiate & load a pretrained model from HF hub.

    >>> from doctr.models import from_hub
    >>> model = from_hub("mindee/fasterrcnn_mobilenet_v3_large_fpn")

    Args:
        repo_id: HuggingFace model hub repo
        kwargs: kwargs of `hf_hub_download` or `snapshot_download`

    Returns:
        Model loaded with the checkpoint

    Raises:
        KeyError: if the downloaded config lacks a required key, or if `arch`
            is not found in the module matching `task`
        ValueError: if `task` is not one of the tasks handled here (this
            includes "obj_detection" when PyTorch is not available)
    """
    # Get the config (binary mode: `json.load` detects the UTF encoding itself)
    with open(hf_hub_download(repo_id, filename="config.json", **kwargs), "rb") as f:
        cfg = json.load(f)

    # `arch` and `task` only select the constructor; they are removed so that
    # the remaining entries can be attached to the model as its configuration.
    arch = cfg.pop("arch")
    task = cfg.pop("task")

    if task == "classification":
        model = models.classification.__dict__[arch](
            pretrained=False, classes=cfg["classes"], num_classes=cfg["num_classes"]
        )
    elif task == "detection":
        model = models.detection.__dict__[arch](pretrained=False)
    elif task == "recognition":
        model = models.recognition.__dict__[arch](
            pretrained=False, input_shape=cfg["input_shape"], vocab=cfg["vocab"]
        )
    elif task == "obj_detection" and is_torch_available():
        model = models.obj_detection.__dict__[arch](
            pretrained=False,
            image_mean=cfg["mean"],
            image_std=cfg["std"],
            max_size=cfg["input_shape"][-1],
            num_classes=len(cfg["classes"]),
        )
    else:
        # Without this branch `model` would be unbound below and the caller
        # would get an UnboundLocalError that hides the actual cause.
        raise ValueError(
            f"Unsupported task {task!r} (arch: {arch!r}) for the available backend"
        )

    # update model cfg
    model.cfg = cfg

    # Load checkpoint
    if is_torch_available():
        # NOTE(security): `torch.load` unpickles the downloaded file; only use
        # this function with repositories you trust.
        state_dict = torch.load(
            hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs),
            map_location="cpu",
        )
        model.load_state_dict(state_dict)
    else:  # tf
        repo_path = snapshot_download(repo_id, **kwargs)
        model.load_weights(os.path.join(repo_path, "tf_model", "weights"))

    return model
def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task: str) -> None:
    """Save model and config to disk for pushing to huggingface hub

    The weights are written as `pytorch_model.bin` when PyTorch is available,
    otherwise under `tf_model/weights` when TensorFlow is available. The
    configuration is written to `config.json` as UTF-8 encoded JSON.

    Args:
        model: TF or PyTorch model to be saved
        save_dir: directory to save model and config (created if missing)
        arch: architecture name
        task: task name
    """
    save_directory = Path(save_dir)
    # Make sure the target exists so that saving does not fail on a fresh path.
    save_directory.mkdir(parents=True, exist_ok=True)

    if is_torch_available():
        weights_path = save_directory / "pytorch_model.bin"
        torch.save(model.state_dict(), weights_path)
    elif is_tf_available():
        weights_path = save_directory / "tf_model" / "weights"
        model.save_weights(str(weights_path))

    # Build the exported configuration from a copy, so that saving does not
    # add "arch"/"task" entries to the live `model.cfg` as a side effect.
    model_config = {**model.cfg, "arch": arch, "task": task}

    config_path = save_directory / "config.json"
    # `ensure_ascii=False` emits non-ASCII characters verbatim, so the file
    # encoding must be explicit rather than depend on the platform locale.
    with config_path.open("w", encoding="utf-8") as f:
        json.dump(model_config, f, indent=2, ensure_ascii=False)
def test_file_utils():
    """Check that the PyTorch backend is reported as available."""
    torch_detected = is_torch_available()
    assert torch_detected
# Inspired by: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/hub.py import json import logging import os import subprocess import textwrap from pathlib import Path from typing import Any from huggingface_hub import HfApi, HfFolder, Repository, hf_hub_download, snapshot_download from doctr import models from doctr.file_utils import is_tf_available, is_torch_available if is_torch_available(): import torch __all__ = [ "login_to_hub", "push_to_hf_hub", "from_hub", "_save_model_and_config_for_hf_hub" ] AVAILABLE_ARCHS = { "classification": models.classification.zoo.ARCHS, "detection": models.detection.zoo.ARCHS + models.detection.zoo.ROT_ARCHS, "recognition": models.recognition.zoo.ARCHS, "obj_detection":
from .. import detection from ..preprocessor import PreProcessor from .predictor import DetectionPredictor __all__ = ["detection_predictor"] ARCHS: List[str] ROT_ARCHS: List[str] if is_tf_available(): ARCHS = [ "db_resnet50", "db_mobilenet_v3_large", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50" ] ROT_ARCHS = ["linknet_resnet18_rotation"] elif is_torch_available(): ARCHS = [ "db_resnet34", "db_resnet50", "db_mobilenet_v3_large", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50", ] ROT_ARCHS = ["db_resnet50_rotation"] def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True, **kwargs: Any) -> DetectionPredictor:
from doctr.file_utils import is_tf_available, is_torch_available

# Re-export the names of the `.pytorch` submodule only when TensorFlow is not
# available and PyTorch is; in every other case nothing is re-exported here.
if not is_tf_available() and is_torch_available():
    from .pytorch import *