Beispiel #1
0
def from_hub(repo_id: str, **kwargs: Any):
    """Instantiate & load a pretrained model from HF hub.

    >>> from doctr.models import from_hub
    >>> model = from_hub("mindee/fasterrcnn_mobilenet_v3_large_fpn")

    Args:
        repo_id: HuggingFace model hub repo
        kwargs: kwargs of `hf_hub_download` or `snapshot_download`

    Returns:
        Model loaded with the checkpoint

    Raises:
        ValueError: if the task stored in the config cannot be built with the
            active backend (an unknown task, or "obj_detection" without
            PyTorch)
    """

    # Get the config
    with open(hf_hub_download(repo_id, filename="config.json", **kwargs),
              "rb") as f:
        cfg = json.load(f)

    # "arch" and "task" only select the constructor below; they are removed
    # so the cfg attached to the model holds the remaining settings only
    arch = cfg.pop("arch")
    task = cfg.pop("task")

    if task == "classification":
        model = models.classification.__dict__[arch](
            pretrained=False,
            classes=cfg["classes"],
            num_classes=cfg["num_classes"])
    elif task == "detection":
        model = models.detection.__dict__[arch](pretrained=False)
    elif task == "recognition":
        model = models.recognition.__dict__[arch](
            pretrained=False,
            input_shape=cfg["input_shape"],
            vocab=cfg["vocab"])
    elif task == "obj_detection" and is_torch_available():
        model = models.obj_detection.__dict__[arch](
            pretrained=False,
            image_mean=cfg["mean"],
            image_std=cfg["std"],
            max_size=cfg["input_shape"][-1],
            num_classes=len(cfg["classes"]),
        )
    else:
        # Without this branch `model` would be unbound and the failure would
        # surface as an unhelpful UnboundLocalError on the next line
        raise ValueError(
            f"Task '{task}' (arch '{arch}') is not supported by the "
            "currently available backend")

    # update model cfg
    model.cfg = cfg

    # Load checkpoint
    if is_torch_available():
        state_dict = torch.load(hf_hub_download(repo_id,
                                                filename="pytorch_model.bin",
                                                **kwargs),
                                map_location="cpu")
        model.load_state_dict(state_dict)
    else:  # tf
        repo_path = snapshot_download(repo_id, **kwargs)
        model.load_weights(os.path.join(repo_path, "tf_model", "weights"))

    return model
Beispiel #2
0
def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str,
                                      task: str) -> None:
    """Save model and config to disk for pushing to huggingface hub

    Weights are written with the active backend: PyTorch is checked first,
    then TensorFlow. The model's ``cfg`` is written to ``config.json``,
    extended with ``arch`` and ``task``. ``model.cfg`` itself is left
    unchanged.

    Args:
        model: TF or PyTorch model to be saved
        save_dir: directory to save model and config
        arch: architecture name
        task: task name
    """
    save_directory = Path(save_dir)

    if is_torch_available():
        weights_path = save_directory / "pytorch_model.bin"
        torch.save(model.state_dict(), weights_path)
    elif is_tf_available():
        weights_path = save_directory / "tf_model" / "weights"
        model.save_weights(str(weights_path))

    config_path = save_directory / "config.json"

    # add model configuration in a fresh dict: assigning into model.cfg would
    # mutate the model's (possibly shared) config as a side effect of saving
    model_config = {**model.cfg, "arch": arch, "task": task}

    # ensure_ascii=False writes non-ASCII characters verbatim, so the file
    # encoding must be pinned to UTF-8 instead of the platform locale default
    with config_path.open("w", encoding="utf-8") as f:
        json.dump(model_config, f, indent=2, ensure_ascii=False)
Beispiel #3
0
def test_file_utils():
    """Check that the PyTorch backend is reported as available."""
    torch_detected = is_torch_available()
    assert torch_detected
Beispiel #4
0
# Inspired by: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/hub.py

import json
import logging
import os
import subprocess
import textwrap
from pathlib import Path
from typing import Any

from huggingface_hub import HfApi, HfFolder, Repository, hf_hub_download, snapshot_download

from doctr import models
from doctr.file_utils import is_tf_available, is_torch_available

if is_torch_available():
    import torch

# Names exported by `from ... import *`; the underscore-prefixed helper must
# be listed explicitly because star-imports skip such names otherwise.
__all__ = [
    "login_to_hub",
    "push_to_hf_hub",
    "from_hub",
    "_save_model_and_config_for_hf_hub",
]

AVAILABLE_ARCHS = {
    "classification":
    models.classification.zoo.ARCHS,
    "detection":
    models.detection.zoo.ARCHS + models.detection.zoo.ROT_ARCHS,
    "recognition":
    models.recognition.zoo.ARCHS,
    "obj_detection":
Beispiel #5
0
from .. import detection
from ..preprocessor import PreProcessor
from .predictor import DetectionPredictor

__all__ = ["detection_predictor"]

# Architecture names available for the installed framework. TensorFlow is
# checked before PyTorch; if neither is available, only the annotations below
# exist and the names stay unassigned.
ARCHS: List[str]
ROT_ARCHS: List[str]

if is_tf_available():
    ARCHS = [
        "db_resnet50",
        "db_mobilenet_v3_large",
        "linknet_resnet18",
        "linknet_resnet34",
        "linknet_resnet50",
    ]
    ROT_ARCHS = [
        "linknet_resnet18_rotation",
    ]
elif is_torch_available():
    ARCHS = [
        "db_resnet34",
        "db_resnet50",
        "db_mobilenet_v3_large",
        "linknet_resnet18",
        "linknet_resnet34",
        "linknet_resnet50",
    ]
    ROT_ARCHS = [
        "db_resnet50_rotation",
    ]


def _predictor(arch: Any,
               pretrained: bool,
               assume_straight_pages: bool = True,
               **kwargs: Any) -> DetectionPredictor:
Beispiel #6
0
from doctr.file_utils import is_tf_available, is_torch_available

if not is_tf_available() and is_torch_available():
    from .pytorch import *