コード例 #1
0
 def test_set_dir(self):
     """hub.set_dir should make hub.load check the repository out there.

     Points the hub directory at the system temp dir, loads the pretrained
     resnet18 entry point from 'pytorch/vision', compares its state dict
     with ``self.resnet18_pretrained`` and checks that the repository
     directory was created under the temp dir. The checkout is removed even
     when an assertion fails, so a stale directory cannot make a later run
     pass the existence check trivially.
     """
     temp_dir = tempfile.gettempdir()
     repo_dir = os.path.join(temp_dir, 'vision_master')
     # NOTE(review): the hub dir is not restored afterwards; later tests in
     # the same process see temp_dir as the hub directory.
     hub.set_dir(temp_dir)
     try:
         hub_model = hub.load('pytorch/vision', 'resnet18', pretrained=True)
         self.assertEqual(self.resnet18_pretrained, hub_model.state_dict())
         assert os.path.exists(repo_dir)
     finally:
         # Guard on existence: if hub.load failed before checking anything
         # out, an unconditional rmtree would raise and mask the real error.
         if os.path.exists(repo_dir):
             shutil.rmtree(repo_dir)
コード例 #2
0
ファイル: test_utils.py プロジェクト: zkz917/pytorch
 def test_set_dir(self):
     """hub.set_dir should make hub.load check the repository out there.

     Points the hub directory at the system temp dir, loads the pretrained
     resnet18 entry point from 'pytorch/vision', compares the sum of its
     parameters with SUM_OF_PRETRAINED_RESNET18_PARAMS and checks that the
     repository directory was created under the temp dir. The checkout is
     removed even when an assertion fails, so a stale directory cannot make
     a later run pass the existence check trivially.
     """
     temp_dir = tempfile.gettempdir()
     repo_dir = os.path.join(temp_dir, 'pytorch_vision_master')
     # NOTE(review): the hub dir is not restored afterwards; later tests in
     # the same process see temp_dir as the hub directory.
     hub.set_dir(temp_dir)
     try:
         hub_model = hub.load('pytorch/vision', 'resnet18', pretrained=True)
         self.assertEqual(sum_of_model_parameters(hub_model),
                          SUM_OF_PRETRAINED_RESNET18_PARAMS)
         assert os.path.exists(repo_dir)
     finally:
         # Guard on existence: if hub.load failed before checking anything
         # out, an unconditional rmtree would raise and mask the real error.
         if os.path.exists(repo_dir):
             shutil.rmtree(repo_dir)
コード例 #3
0
ファイル: test_utils.py プロジェクト: yuan50697105/pytorch
 def test_set_dir(self):
     """hub.set_dir should make hub.load check the repository out there.

     Points the hub directory at the system temp dir, loads the pretrained
     'mnist' entry point from 'ailzhang/torchhub_example', compares the sum
     of its state dict with SUM_OF_HUB_EXAMPLE and checks that the
     repository directory was created under the temp dir. The checkout is
     removed even when an assertion fails, so a stale directory cannot make
     a later run pass the existence check trivially.
     """
     temp_dir = tempfile.gettempdir()
     repo_dir = os.path.join(temp_dir, 'ailzhang_torchhub_example_master')
     # NOTE(review): the hub dir is not restored afterwards; later tests in
     # the same process see temp_dir as the hub directory.
     hub.set_dir(temp_dir)
     try:
         hub_model = hub.load('ailzhang/torchhub_example',
                              'mnist',
                              pretrained=True,
                              verbose=False)
         self.assertEqual(sum_of_state_dict(hub_model.state_dict()),
                          SUM_OF_HUB_EXAMPLE)
         assert os.path.exists(repo_dir)
     finally:
         # Guard on existence: if hub.load failed before checking anything
         # out, an unconditional rmtree would raise and mask the real error.
         if os.path.exists(repo_dir):
             shutil.rmtree(repo_dir)
コード例 #4
0
 def test_set_dir(self):
     """hub.set_dir should make hub.load check the repository out there.

     Points the hub directory at the system temp dir, loads resnet18 from
     'pytorch/vision' with weights="DEFAULT", compares the sum of its
     parameters (approximately) with SUM_OF_PRETRAINED_RESNET18_PARAMS and
     checks that the repository directory was created under the temp dir.
     The checkout is removed even when an assertion fails, so a stale
     directory cannot make a later run pass the existence check trivially.
     """
     temp_dir = tempfile.gettempdir()
     repo_dir = os.path.join(temp_dir, "pytorch_vision_master")
     # NOTE(review): the hub dir is not restored afterwards; later tests in
     # the same process see temp_dir as the hub directory.
     hub.set_dir(temp_dir)
     try:
         hub_model = hub.load("pytorch/vision",
                              "resnet18",
                              weights="DEFAULT",
                              progress=False)
         assert sum_of_model_parameters(hub_model).item() == pytest.approx(
             SUM_OF_PRETRAINED_RESNET18_PARAMS)
         assert os.path.exists(repo_dir)
     finally:
         # Guard on existence: if hub.load failed before checking anything
         # out, an unconditional rmtree would raise and mask the real error.
         if os.path.exists(repo_dir):
             shutil.rmtree(repo_dir)
コード例 #5
0
 def test_set_dir(self):
     """hub.set_dir should make hub.load check the repository out there.

     Points the hub directory at the system temp dir, loads the pretrained
     resnet18 entry point from 'pytorch/vision', compares the sum of its
     parameters (approximately) with SUM_OF_PRETRAINED_RESNET18_PARAMS and
     checks that the repository directory was created under the temp dir.
     The checkout is removed even when an assertion fails, so a stale
     directory cannot make a later run pass the existence check trivially.
     """
     temp_dir = tempfile.gettempdir()
     repo_dir = os.path.join(temp_dir, 'pytorch_vision_master')
     # NOTE(review): the hub dir is not restored afterwards; later tests in
     # the same process see temp_dir as the hub directory.
     hub.set_dir(temp_dir)
     try:
         hub_model = hub.load('pytorch/vision',
                              'resnet18',
                              pretrained=True,
                              progress=False)
         assert sum_of_model_parameters(hub_model).item() == pytest.approx(
             SUM_OF_PRETRAINED_RESNET18_PARAMS)
         assert os.path.exists(repo_dir)
     finally:
         # Guard on existence: if hub.load failed before checking anything
         # out, an unconditional rmtree would raise and mask the real error.
         if os.path.exists(repo_dir):
             shutil.rmtree(repo_dir)
コード例 #6
0
ファイル: train.py プロジェクト: joizhang/ffd-attention
import warnings

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch import hub, optim
from torch.backends import cudnn

from config import Config
from training import models
from training.datasets import get_dataloader
from training.tools.train_utils import parse_args, train, validate

CONFIG = Config()

# Route torch.hub's download/cache directory to the configured TORCH_HOME
# before any model code runs.
hub.set_dir(CONFIG['TORCH_HOME'])

# Enable cuDNN autotuning of convolution algorithms (same flag as
# torch.backends.cudnn.benchmark). NOTE(review): this pays off only when
# input shapes stay fixed across iterations — confirm for the dataloaders.
cudnn.benchmark = True


def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
コード例 #7
0
import argparse
from pathlib import Path

from torch import hub
from torchvision import models

if __name__ == "__main__":
    # CLI: instantiate a pretrained torchvision model with torch.hub's
    # directory redirected to the given path. Presumably the point is the
    # side effect of fetching the weights into that directory — confirm.
    parser = argparse.ArgumentParser()
    parser.add_argument("model_arch", type=str, help="Architecture of the torchvision model")
    parser.add_argument("torch_hub", type=str, help="Path to set as torch hub")
    args = parser.parse_args()

    # Resolve the model builder up front so a typo (or a name that is not a
    # callable builder) produces a clean usage error instead of a bare
    # AttributeError/TypeError traceback.
    model_builder = getattr(models, args.model_arch, None)
    if not callable(model_builder):
        parser.error(
            "unknown torchvision model architecture: {!r}".format(args.model_arch))

    hub_path = Path(args.torch_hub)
    hub.set_dir(str(hub_path.resolve()))
    model = model_builder(pretrained=True)
コード例 #8
0
import json
from pathlib import Path
from typing import List, Tuple

from PIL.Image import Image
from torch import hub

import whatplane.models.model_helpers as mh
from whatplane.models.predict_model import predict_image_data

# NOTE(review): paths are relative to the current working directory, not to
# this file, so the process must be started from the project root.
BASE_DIR = Path(".")
MODELS_DIR = BASE_DIR / "models"

# Read the class-index JSON explicitly as UTF-8 so loading does not depend on
# the platform's default locale encoding. Each value is indexed with [1] below
# to obtain a class name — presumably the {"0": [wnid, name], ...} layout;
# TODO confirm.
with open(BASE_DIR / "api/imagenet_class_index.json", encoding="utf-8") as f:
    imagenet_class_index = json.load(f)

# Set PyTorch cache directory to where the model is stored (presumably
# mh.initialize_model resolves pretrained weights through torch.hub).
hub.set_dir(str(MODELS_DIR.resolve()))
imagenet_model = mh.initialize_model(
    "densenet161",
    [item[1] for item in imagenet_class_index.values()],
    replace_classifier=False,
)
whatplane_model = mh.load_model(MODELS_DIR / "model.pth")


def should_predict_whatplane(imagenet_probs: List[float],
                             imagenet_classes: List[str]) -> bool:
    """Function to check which model to use for prediction.

    Args:
        imagenet_probs (List[float]): List of probabilities returned from ImageNet
        imagenet_classes (List[str]): List of class names returned from ImageNet