Example #1
def download_models(translation_config):
    # Step 1: Form the key
    language_direction = translation_config["language_direction"].lower()
    dataset_name = translation_config["dataset_name"].lower()
    key = f"{dataset_name}_{language_direction}"

    # Step 2: Check whether this model already exists
    model_name = f"{key}.pth"
    model_path = os.path.join(BINARIES_PATH, model_name)
    if os.path.exists(model_path):
        print(
            f"No need to download, found model {model_path} that was trained on {dataset_name} for language direction {language_direction}."
        )
        return model_path

    # Step 3: Download the resource to local filesystem
    remote_resource_path = DOWNLOAD_DICT[key]
    if remote_resource_path is None:  # handle models which I've not provided URLs for yet
        print(
            f"No model found that was trained on {dataset_name} for language direction {language_direction}."
        )
        exit(0)

    print(f"Downloading from {remote_resource_path}. This may take a while.")
    download_url_to_file(remote_resource_path, model_path)

    return model_path
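A minimal usage sketch for the helper above, assuming BINARIES_PATH and DOWNLOAD_DICT are module-level constants as the code implies; the config values below are hypothetical:

# Hypothetical usage; the dataset and direction identifiers are illustrative only.
translation_config = {
    "dataset_name": "IWSLT",
    "language_direction": "E2G",
}
model_path = download_models(translation_config)  # local path to the cached .pth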
Example #2
    def test_download_url_to_file(self):
        temp_file = os.path.join(tempfile.gettempdir(), 'temp')
        hub.download_url_to_file(TORCHHUB_EXAMPLE_RELEASE_URL,
                                 temp_file,
                                 progress=False)
        loaded_state = torch.load(temp_file)
        self.assertEqual(sum_of_state_dict(loaded_state), SUM_OF_HUB_EXAMPLE)
Example #3
def get_pretrained(model="9L-L-CoDA-SQ-100000", dataset="Imagenet"):
    """
    Loading the pretrained models for evaluation.
    Returns:
       trainer object with pretrained model
    """
    assert dataset in _model_urls and model in _model_urls[
        dataset], "URL for this model is not specified."
    url, epoch = _model_urls[dataset][model]
    # This model_path convention allows to identify the experiment just from the path and simplifies reloading.
    # (E.g., for the interpretability analysis scripts).
    model_path = os.path.join(BASE_PATH, dataset, "final", model)
    model_file = os.path.join(model_path, "model_epoch_{}.pkl".format(epoch))
    if not os.path.exists(model_file):
        if not os.path.exists(model_path):
            os.makedirs(model_path)
        download_url_to_file(url, model_file, progress=True)
    # Specify model parameters defined in Imagenet.final.experiment_parameters
    exp_params = exps[model]
    # Load model according to experiment parameters
    model = get_model(exp_params)
    # Load data according to experiment parameters
    data_handler = Data(dataset, only_test_loader=True, **exp_params)
    trainer = Trainer(
        model,
        data_handler,
        model_path,  # Setting the path in the trainer to be able to use trainer.reload.
        **exp_params)

    trainer.reload()
    trainer.model.cuda()
    return trainer
Example #4
    def mini_download():
        """Downloads MiniLibriMix from Zenodo in current directory

        Returns:
            The path to the metadata directory.
        """
        mini_dir = "./MiniLibriMix/"
        os.makedirs(mini_dir, exist_ok=True)
        # Download zip (or cached)
        zip_path = mini_dir + "MiniLibriMix.zip"
        if not os.path.isfile(zip_path):
            hub.download_url_to_file(MINI_URL, zip_path)
        # Unzip zip
        cond = all([
            os.path.isdir("MiniLibriMix/" + f)
            for f in ["train", "val", "metadata"]
        ])
        if not cond:
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall("./")  # Will unzip in MiniLibriMix
        # Reorder metadata
        src = "MiniLibriMix/metadata/"
        for mode in ["train", "val"]:
            dst = f"MiniLibriMix/metadata/{mode}/"
            os.makedirs(dst, exist_ok=True)
            [
                shutil.copyfile(src + f, dst + f) for f in os.listdir(src)
                if mode in f and os.path.isfile(src + f)
            ]
        return "./MiniLibriMix/metadata"
Example #5
def load_file_from_url(url,
                       model_dir=None,
                       progress=True,
                       check_hash=False,
                       file_name=None):
    if model_dir is None:
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, 'checkpoints')

    try:
        os.makedirs(model_dir)
    except OSError as e:
        if e.errno == errno.EEXIST:
            # Directory already exists, ignore.
            pass
        else:
            # Unexpected OSError, re-raise.
            raise

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        if check_hash:
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)

    return cached_file
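For reference, the HASH_REGEX used above comes from torch.hub; recent releases define it roughly as below (the exact pattern may vary between versions). It extracts the hash prefix embedded in filenames such as resnet18-5c106cde.pth:

import re

# A sketch of torch.hub's pattern; an assumption based on recent releases.
HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')

match = HASH_REGEX.search('resnet18-5c106cde.pth')
print(match.group(1))  # '5c106cde', passed to download_url_to_file as hash_prefix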
Example #6
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Load pretrained model
        if self.pretrain_model is not None:
            self.model = torch.load(self.pretrain_model)['model']

        # Load pretrained model from URL
        else:
            # Check cache directory
            cache_dir = osp.expanduser("~/.cache/torch/checkpoints/")
            if not osp.exists(cache_dir):
                os.makedirs(cache_dir)

            # Cache file name
            cache_file = osp.basename(YOLOv5.PRETRAIN_URL).split("?")[0]
            if not osp.exists(osp.join(cache_dir, cache_file)):
                download_url_to_file(YOLOv5.PRETRAIN_URL,
                                     osp.join(cache_dir, cache_file))

            self.model = torch.load(osp.join(cache_dir, cache_file))['model']

        self.model.to(self.device)
        self.model.eval()

        if "cuda" in self.device:
            self.model = self.model.half()
Example #7
def cached_download(filename_or_url):
    """ Download from URL with torch.hub and cache the result in ASTEROID_CACHE.

    Args:
        filename_or_url (str): Name of a model as named on the Zenodo Community
            page (ex: mpariente/ConvTasNet_WHAM!_sepclean), or a URL to a model
            file (ex: https://zenodo.org/.../model.pth), or a filename
            that exists locally (ex: local/tmp_model.pth)

    Returns:
        str, normalized path to the downloaded (or not) model
    """
    if os.path.isfile(filename_or_url):
        return filename_or_url

    if filename_or_url in MODELS_URLS_HASHTABLE:
        url = MODELS_URLS_HASHTABLE[filename_or_url]
    else:
        # Give a chance to direct URL, torch.hub will handle exceptions
        url = filename_or_url
    cached_filename = url_to_filename(url)
    cached_dir = os.path.join(get_cache_dir(), cached_filename)
    cached_path = os.path.join(cached_dir, "model.pth")

    os.makedirs(cached_dir, exist_ok=True)
    if not os.path.isfile(cached_path):
        hub.download_url_to_file(url, cached_path)
        return cached_path
    # It was already downloaded
    print(f"Using cached model `{filename_or_url}`")
    return cached_path
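A minimal sketch of calling this helper, reusing the elided Zenodo URL from the docstring; it assumes MODELS_URLS_HASHTABLE, url_to_filename, and get_cache_dir from the same module:

import torch

# Accepts a registered model name, a direct URL, or an existing local path.
path = cached_download("https://zenodo.org/.../model.pth")
state = torch.load(path)  # inspect or reload the cached checkpoint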
Example #8
    def mini_download():
        """ Downloads MiniLibriMix from Zenodo in current directory

        Returns:
            The path to the metadata directory.
        """
        mini_dir = './MiniLibriMix/'
        os.makedirs(mini_dir, exist_ok=True)
        # Download zip (or cached)
        zip_path = mini_dir + 'MiniLibriMix.zip'
        if not os.path.isfile(zip_path):
            hub.download_url_to_file(MINI_URL, zip_path)
        # Unzip zip
        cond = all([
            os.path.isdir('MiniLibriMix/' + f)
            for f in ['train', 'val', 'metadata']
        ])
        if not cond:
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall('./')  # Will unzip in MiniLibriMix
        # Reorder metadata
        src = 'MiniLibriMix/metadata/'
        for mode in ['train', 'val']:
            dst = f'MiniLibriMix/metadata/{mode}/'
            os.makedirs(dst, exist_ok=True)
            [
                shutil.copyfile(src + f, dst + f) for f in os.listdir(src)
                if mode in f and os.path.isfile(src + f)
            ]
        return './MiniLibriMix/metadata'
Example #9
    def download(cls, out_dir, sample_rate=16000):
        os.makedirs(out_dir, exist_ok=True)
        exists_cond = all([
            os.path.isdir(os.path.join(out_dir, 'data')),
            os.path.isfile(os.path.join(out_dir, 'train_data.csv')),
            os.path.isfile(os.path.join(out_dir, 'test_data.csv'))
        ])
        if exists_cond:
            print('Dataset seems to be already downloaded and extracted')
            return

        zip_path = os.path.join(out_dir, 'timit.zip')
        hub.download_url_to_file(cls.download_url, zip_path)

        print('Extracting files...')
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(out_dir)

        os.remove(zip_path)
        print('Done')

        if sample_rate != cls.default_sample_rate:
            print(
                f'Resampling from original sample rate {cls.default_sample_rate} Hz to {sample_rate} Hz'
            )
            timit_train = cls(out_dir, subset='train', sample_rate=sample_rate)
            timit_test = cls(out_dir, subset='test', sample_rate=sample_rate)
            timit_train.save_back()
            timit_test.save_back()
Example #10
def download_model_weights(model_name, filename):
    model_to_url = {
        'efficientdet-d0':
        'https://github.com/sevakon/efficientdet/releases/download/2.0/efficientdet-d0.pth',
        'efficientdet-d1':
        'https://github.com/sevakon/efficientdet/releases/download/2.0/efficientdet-d1.pth',
        'efficientdet-d2':
        'https://github.com/sevakon/efficientdet/releases/download/2.0/efficientdet-d2.pth',
        'efficientdet-d3':
        'https://github.com/sevakon/efficientdet/releases/download/v1.0/efficientdet-d3.pth',
        'efficientdet-d4':
        'https://github.com/sevakon/efficientdet/releases/download/v1.0/efficientdet-d4.pth',
        'efficientdet-d5':
        'https://github.com/sevakon/efficientdet/releases/download/v1.0/efficientdet-d5.pth',
        'efficientnet-b0':
        'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth',
        'efficientnet-b1':
        'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth',
        'efficientnet-b2':
        'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth',
        'efficientnet-b3':
        'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth',
        'efficientnet-b4':
        'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth',
        'efficientnet-b5':
        'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth',
        'efficientnet-b6':
        'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth',
        'efficientnet-b7':
        'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth',
    }
    download_url_to_file(model_to_url[model_name], filename)
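Note that these release filenames embed a short hex hash (e.g. '355c32eb'), following the ``filename-<sha256>.ext`` convention that the hash-checking loaders elsewhere on this page rely on; a small sketch of extracting it:

import re

name = 'efficientnet-b0-355c32eb.pth'
prefix = re.search(r'-([a-f0-9]*)\.', name).group(1)
print(prefix)  # '355c32eb'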
Example #11
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Load file form http url, will download models if necessary.

    Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url (str): URL to be downloaded.
        model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
            Default: None.
        progress (bool): Whether to show the download progress. Default: True.
        file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.

    Returns:
        str: The path to the downloaded file.
    """
    if model_dir is None:  # use the pytorch hub_dir
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, 'checkpoints')

    os.makedirs(model_dir, exist_ok=True)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.abspath(os.path.join(model_dir, filename))
    if not os.path.exists(cached_file):
        print(f'Downloading: "{url}" to {cached_file}\n')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file
Example #12
def download_data(url, save_as_file=None):
    if save_as_file is None:
        # Derive a local filename from the URL when none is given.
        save_as_file = os.path.basename(urlparse(url).path)
    if os.path.exists(save_as_file):
        return
    try:
        download_url_to_file(url, save_as_file)
    except Exception as e:  # avoid a bare except that hides the cause
        print(f"download failed: {e}")
Example #13
def cached_download(filename_or_url):
    """Download from URL and cache the result in ASTEROID_CACHE.

    Args:
        filename_or_url (str): Name of a model as named on the Zenodo Community
            page (ex: ``"mpariente/ConvTasNet_WHAM!_sepclean"``), or model id from
            the Hugging Face model hub (ex: ``"julien-c/DPRNNTasNet-ks16_WHAM_sepclean"``),
            or a URL to a model file (ex: ``"https://zenodo.org/.../model.pth"``), or a filename
            that exists locally (ex: ``"local/tmp_model.pth"``)

    Returns:
        str, normalized path to the downloaded (or not) model
    """
    from .. import __version__ as asteroid_version  # Avoid circular imports

    if os.path.isfile(filename_or_url):
        return filename_or_url

    if filename_or_url.startswith(huggingface_hub.HUGGINGFACE_CO_URL_HOME):
        filename_or_url = filename_or_url[
            len(huggingface_hub.HUGGINGFACE_CO_URL_HOME):]

    if filename_or_url.startswith(("http://", "https://")):
        url = filename_or_url
    elif filename_or_url in MODELS_URLS_HASHTABLE:
        url = MODELS_URLS_HASHTABLE[filename_or_url]
    else:
        # Finally, let's try to find it on Hugging Face model hub
        # e.g. julien-c/DPRNNTasNet-ks16_WHAM_sepclean is a valid model id
        # and  julien-c/DPRNNTasNet-ks16_WHAM_sepclean@main supports specifying a commit/branch/tag.
        if "@" in filename_or_url:
            model_id = filename_or_url.split("@")[0]
            revision = filename_or_url.split("@")[1]
        else:
            model_id = filename_or_url
            revision = None
        url = huggingface_hub.hf_hub_url(
            model_id,
            filename=huggingface_hub.PYTORCH_WEIGHTS_NAME,
            revision=revision)
        return huggingface_hub.cached_download(
            url,
            cache_dir=get_cache_dir(),
            library_name="asteroid",
            library_version=asteroid_version,
        )
    cached_filename = url_to_filename(url)
    cached_dir = os.path.join(get_cache_dir(), cached_filename)
    cached_path = os.path.join(cached_dir, "model.pth")

    os.makedirs(cached_dir, exist_ok=True)
    if not os.path.isfile(cached_path):
        hub.download_url_to_file(url, cached_path)
        return cached_path
    # It was already downloaded
    print(f"Using cached model `{filename_or_url}`")
    return cached_path
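A brief usage sketch of the input forms described in the docstring and comments above; the model id is the one given there, with '@main' pinning a revision:

# A Zenodo model name, direct URL, local path, or HF model id all resolve.
path = cached_download("julien-c/DPRNNTasNet-ks16_WHAM_sepclean@main")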
Example #14
File: test_hub.py Project: yuguo68/pytorch
    def test_download_url_to_file(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            f = os.path.join(tmpdir, 'temp')
            hub.download_url_to_file(TORCHHUB_EXAMPLE_RELEASE_URL,
                                     f,
                                     progress=False)
            loaded_state = torch.load(f)
            self.assertEqual(sum_of_state_dict(loaded_state),
                             SUM_OF_HUB_EXAMPLE)
Example #15
def download_url(url, download_dir=None, filename=None):
    parts = urlparse(url)
    if download_dir is None:
        hub_dir = get_dir()
        download_dir = os.path.join(hub_dir, 'checkpoints')
    if filename is None:
        filename = os.path.basename(parts.path)
    cached_file = os.path.join(download_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        download_url_to_file(url, cached_file)
    return cached_file
Example #16
def download_cached_file(url, check_hash=True, progress=False):
    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(get_cache_dir(), filename)
    if not os.path.exists(cached_file):
        _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        if check_hash:
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    return cached_file
Example #17
def load_custom_pretrained(model,
                           cfg=None,
                           load_fn=None,
                           progress=False,
                           check_hash=False):
    r"""Loads a custom (read non .pth) weight file

    Downloads the checkpoint file into the cache dir like torch.hub-based loaders, but calls
    a passed-in custom load function, or the model's `load_pretrained` member function.

    If the object is already present in `model_dir`, it's deserialized and returned.
    The default value of `model_dir` is ``<hub_dir>/checkpoints`` where
    `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`.

    Args:
        model: The instantiated model to load weights into
        cfg (dict): Default pretrained model cfg
        load_fn: An external standalone function that loads weights into the provided model; otherwise
            a function named 'load_pretrained' on the model will be called if it exists
        progress (bool, optional): whether or not to display a progress bar to stderr. Default: False
        check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention
            ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
            digits of the SHA256 hash of the contents of the file. The hash is used to
            ensure unique names and to verify the contents of the file. Default: False
    """
    cfg = cfg or getattr(model, 'default_cfg', None)  # tolerate models without a default_cfg
    if cfg is None or not cfg.get('url', None):
        _logger.warning(
            "No pretrained weights exist for this model. Using random initialization."
        )
        return
    url = cfg['url']

    model_dir = get_cache_dir()
    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        if check_hash:
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)

    if load_fn is not None:
        load_fn(model, cached_file)
    elif hasattr(model, 'load_pretrained'):
        model.load_pretrained(cached_file)
    else:
        _logger.warning(
            "Valid function to load pretrained weights is not available, using random initialization."
        )
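A minimal sketch of the load_fn hook described in the docstring; the loader, the cfg URL, and the model object below are all hypothetical, and the weight-copying step is elided:

def my_numpy_loader(model, cached_file):
    # Hypothetical loader for a non-.pth format (here, a NumPy archive).
    import numpy as np
    weights = np.load(cached_file)
    # ... copy the arrays into the model's parameters ...

load_custom_pretrained(model,  # 'model' stands for any instantiated network
                       cfg={'url': 'https://example.com/weights.npz'},
                       load_fn=my_numpy_loader)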
Example #18
def prepare_celeba(celeba_path: str, img_size: int):
    r"""Downloading and preparing CelebA dataset.

    Args:
        -celeba_path (str): Where to store dataset.
        -img_size (int): Size of singleton image in the dataset.
    """

    celeba_url = r"https://s3.amazonaws.com/video.udacity-data.com/topher/2018/November/5be7eb6f_processed-celeba-small/processed-celeba-small.zip"

    # download dataset
    print("downloading CelebA...")
    resource_tmp_path = celeba_path + ".zip"
    download_url_to_file(celeba_url, resource_tmp_path)
    print("finished")

    # unzipping downloaded dataset
    print("unzipping...")
    with zipfile.ZipFile(resource_tmp_path) as zf:
        os.makedirs(celeba_path, exist_ok=True)
        zf.extractall(path=celeba_path)
    print("finished")

    # removing temporary files
    os.remove(resource_tmp_path)
    print("removed temporary file")

    # preparing dataset
    print("preparing CelebA...")
    shutil.rmtree(os.path.join(celeba_path, "__MACOSX"))
    dst_data_directory = os.path.join(celeba_path, "processed_celeba_small")
    os.remove(os.path.join(dst_data_directory, ".DS_Store"))
    data_directory1 = os.path.join(dst_data_directory, "celeba")
    data_directory2 = os.path.join(data_directory1, "New Folder With Items")
    for element in os.listdir(data_directory1):
        if not element.endswith(".jpg"):
            continue

        if os.path.isfile(os.path.join(data_directory1, element)):
            shutil.move(os.path.join(data_directory1, element),
                        os.path.join(dst_data_directory, element))

    for element in os.listdir(data_directory2):
        if not element.endswith(".jpg"):
            continue

        if os.path.isfile(os.path.join(data_directory2, element)):
            shutil.move(os.path.join(data_directory2, element),
                        os.path.join(dst_data_directory, element))

    shutil.rmtree(data_directory1)
    print("finished")
Example #19
File: FaceNet.py Project: Neihtq/fasecure
def load_state():
    path = 'https://github.com/khrlimam/facenet/releases/download/acc-0.92135/model921-af60fb4f.pth'

    model_dir = "./pretrained_model"
    os.makedirs(model_dir, exist_ok=True)

    cached_file = os.path.join(model_dir, os.path.basename(path))
    if not os.path.exists(cached_file):
        download_url_to_file(path, cached_file)

    state_dict = torch.load(cached_file)  
    
    return state_dict
Example #20
def custom_cache_url(url: str,
                     model_dir: Path = None,
                     progress: bool = True) -> Path:
    r"""Loads the Torch serialized object at the given URL.
If the object is already present in `model_dir`, it's deserialized and
returned. The filename part of the URL should follow the naming convention
``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
digits of the SHA256 hash of the contents of the file. The hash is used to
ensure unique names and to verify the contents of the file.
The default value of `model_dir` is ``$TORCH_HOME/models`` where
``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be
overridden with the ``$TORCH_MODEL_ZOO`` environment variable.
Args:
url (string): URL of the object to download
model_dir (string, optional): directory in which to save the object
progress (bool, optional): whether or not to display a progress bar to stderr
Example:
>>> cached_file = maskrcnn_benchmark.utils.model_zoo.custom_cache_url(
'https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
"""
    if model_dir is None:
        # os.getenv returns a str when $TORCH_MODEL_ZOO is set; wrap in Path
        # so the .exists()/.mkdir() calls below work either way.
        model_dir = Path(
            os.getenv(
                "TORCH_MODEL_ZOO",
                Path(os.path.expanduser(os.getenv("TORCH_HOME", "~/.torch"))) /
                "models",
            ))
    if not model_dir.exists():
        model_dir.mkdir(parents=True)
    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if filename == "model_final.pkl":
        # workaround as pre-trained Caffe2 models from Detectron have all the same filename
        # so make the full path the filename by replacing / with _
        filename = parts.path.replace("/", "_")
    cached_file = model_dir / filename
    if not cached_file.exists() and is_main_process():
        sys.stderr.write(f'Downloading: "{url}" to {cached_file}\n')
        hash_prefix = HASH_REGEX.search(filename)
        if hash_prefix is not None:
            hash_prefix = hash_prefix.group(1)
            # workaround: Caffe2 models don't have a hash, but follow the R-50 convention,
            # which matches the hash PyTorch uses. So we skip the hash matching
            # if the hash_prefix is less than 6 characters
            if len(hash_prefix) < 6:
                hash_prefix = None
        download_url_to_file(url,
                             str(cached_file),
                             hash_prefix,
                             progress=progress)
    synchronise_torch_barrier()
    return cached_file
Example #21
def get_cached_file_path(url: str,
                         save_dir: Optional[str] = None,
                         progress: bool = True,
                         check_hash: bool = False,
                         file_name: Optional[str] = None) -> str:
    r"""Loads the Torch serialized object at the given URL.

    If downloaded file is a zip file, it will be automatically decompressed

    If the object is already present in `model_dir`, it's deserialized and
    returned.
    The default value of ``model_dir`` is ``<hub_dir>/checkpoints`` where
    ``hub_dir`` is the directory returned by :func:`~torch.hub.get_dir`.

    Args:
        url (str): URL of the object to download
        save_dir (str, optional): directory in which to save the object
        progress (bool): whether or not to display a progress bar
            to stderr. Default: ``True``
        check_hash(bool): If True, the filename part of the URL
            should follow the naming convention ``filename-<sha256>.ext``
            where ``<sha256>`` is the first eight or more digits of the
            SHA256 hash of the contents of the file. The hash is used to
            ensure unique names and to verify the contents of the file.
            Default: ``False``
        file_name (str, optional): name for the downloaded file. Filename
            from ``url`` will be used if not set. Default: ``None``.

    Returns:
        str: The path to the cached file.
    """
    if save_dir is None:
        save_dir = os.path.join('webcam_resources')

    mkdir_or_exist(save_dir)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.join(save_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        if check_hash:
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    return cached_file
Example #22
def get_file_from_url(url, model_dir=None, progress=True, unzip=True):
    """
    Download model from url and return path to downloaded model.

    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object
        progress (bool, optional): whether or not to display a progress bar to stderr.
            Default: True
        unzip (bool, optional): whether to use unzip. Don't use with TorchScript models!
            Default: True

    Returns: downloaded model path

    """
    if model_dir is None:
        torch_home = _get_torch_home()
        model_dir = os.path.join(torch_home, "checkpoints")

    try:
        os.makedirs(model_dir)
    except OSError as e:
        if e.errno == errno.EEXIST:
            # Directory already exists, ignore.
            pass
        else:
            # Unexpected OSError, re-raise.
            raise

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        logger.info(f"Downloading: {url} to {cached_file}")
        download_url_to_file(url, cached_file, None, progress=progress)

    # Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand.
    #       We deliberately don't handle tarfile here since our legacy serialization format was in tar.
    #       E.g. resnet18-5c106cde.pth which is widely used.
    if zipfile.is_zipfile(cached_file) and unzip:
        with zipfile.ZipFile(cached_file) as cached_zipfile:
            members = cached_zipfile.infolist()
            if len(members) != 1:
                raise RuntimeError(
                    "Only one file(not dir) is allowed in the zipfile")
            cached_zipfile.extractall(model_dir)
            extraced_name = members[0].filename
            cached_file = os.path.join(model_dir, extraced_name)
    return cached_file
Example #23
    def test_ppl_gen_model(self):
        model = os.path.basename(URL_SNGAN_MODEL)
        model = os.path.realpath(os.path.join(tempfile.gettempdir(), model))
        download_url_to_file(URL_SNGAN_MODEL, model, progress=True)
        self.assertTrue(os.path.isfile(model))

        print(f'Running fidelity PPL...', file=sys.stderr)
        res_fidelity = self.call_fidelity_ppl(model, 100)
        self.assertEqual(res_fidelity.returncode, 0, msg=res_fidelity)
        res_fidelity = json_decode_string(res_fidelity.stdout.decode())
        print('Fidelity PPL result:', res_fidelity, file=sys.stderr)

        self.assertAlmostEqual(res_fidelity[KEY_METRIC_PPL_MEAN],
                               2560.187255859375,
                               delta=1e-5)
Example #24
def download_and_unpack_noises(directory):
    directory = PurePath(directory)
    os.makedirs(directory, exist_ok=True)
    if os.path.isdir(directory / 'noises-train-drones') and os.path.isdir(directory / 'noises-test-drones'):
        logger.info('Noises data seems to be already loaded')
        return
    
    logger.info('Downloading and extracting noises...')
    zip_path = directory / 'noises.zip'
    hub.download_url_to_file(NOISES_URL, zip_path)
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(directory)
        
    os.remove(zip_path)
    logger.info('Done')
Example #25
File: hub.py Project: bernardomig/ark
def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None):
    r"""A modified version of `torch.hub.load_state_dict_from_url`, which handles the new 
    serialization protocol diferently. 
    See <https://github.com/pytorch/pytorch/issues/43106> for more information.
    """
    import os
    import sys
    import warnings
    import errno
    import torch
    from urllib.parse import urlparse
    from torch.hub import get_dir, download_url_to_file, HASH_REGEX

    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    if model_dir is None:
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, 'checkpoints')

    try:
        os.makedirs(model_dir)
    except OSError as e:
        if e.errno == errno.EEXIST:
            # Directory already exists, ignore.
            pass
        else:
            # Unexpected OSError, re-raise.
            raise

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        if check_hash:
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)

    return torch.load(cached_file, map_location=map_location)
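Usage mirrors torch.hub.load_state_dict_from_url; a hedged sketch, where 'model' stands for any nn.Module with matching keys and the URL is a torchvision checkpoint whose filename carries a hash prefix:

state_dict = load_state_dict_from_url(
    'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    check_hash=True)
model.load_state_dict(state_dict)  # 'model' is illustrative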
Example #26
def get_datasets(versions):
    squad_links = {
        SQuADVersion.V11:
        "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json",
        SQuADVersion.V20:
        "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json"
    }
    filenames = {
        SQuADVersion.V11: "dev-v1.1.json",
        SQuADVersion.V20: "dev-v2.0.json"
    }
    data_dir = Path(".data") if is_server() else Path("data")
    datasets_path = data_dir / "nlp" / "squad"
    datasets_path.mkdir(parents=True, exist_ok=True)
    for version in versions:
        filename = datasets_path / filenames[version]
        if not filename.exists():
            download_url_to_file(squad_links[version], filename)
Example #27
def main(argv):
    if not os.path.exists(args.weights):
        download_url_to_file(PRETRAINED_WEIGHTS_URL,
                             args.weights,
                             progress=True)

    device = args.device
    if not torch.cuda.is_available():
        device = 'cpu'

    if args.mode == 1:
        pass
    elif args.mode == 2:
        pass
    elif args.mode == 3:
        camfeed_inference(args.weights, device)
    else:
        raise ModeError
Example #28
def unzip_file_from_url(output_dir, url):
    output_dir = os.path.realpath(output_dir)

    with tempfile.TemporaryDirectory() as temp_dir:
        parts = urlparse(url)
        filename = os.path.basename(parts.path)
        temp_path = os.path.join(temp_dir, filename)
        if not os.path.exists(temp_path):
            print('Downloading: "{}" to {}\n'.format(url, temp_path))
            download_url_to_file(url, temp_path)

        assert zipfile.is_zipfile(temp_path)

        with zipfile.ZipFile(temp_path) as temp_zipfile:
            for member in tqdm(temp_zipfile.infolist(),
                               desc='Extracting ',
                               ncols=0):
                temp_zipfile.extract(member, output_dir)
Example #29
def load_dox_url(url,
                 filename,
                 model_dir=None,
                 map_location=None,
                 progress=True):
    r"""Adapt to fit format file of mtdp pre-trained models
    """
    if model_dir is None:
        torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch'))
        model_dir = os.getenv('TORCH_MODEL_ZOO',
                              os.path.join(torch_home, 'models'))
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        sys.stderr.flush()
        download_url_to_file(url, cached_file, None, progress=progress)
    return torch.load(cached_file, map_location=map_location)
Example #30
    def load_onnx_from_zoo(self, model_name, sensor, version="onnx"):
        model_url = self._get_zoo_model_url(model_name, sensor.name, version)
        hub_dir = os.path.join(hub.get_dir(), "checkpoints")

        # isfile() was the wrong existence check for a directory; create it idempotently.
        os.makedirs(hub_dir, exist_ok=True)

        # from https://github.com/pytorch/pytorch/blob/master/torch/hub.py
        parts = urlparse(model_url)
        filename = os.path.basename(parts.path)
        cached_file = os.path.join(hub_dir, filename)
        if not os.path.exists(cached_file):
            sys.stderr.write('Downloading: "{}" to {}\n'.format(
                model_url, cached_file))
            hub.download_url_to_file(model_url,
                                     cached_file,
                                     None,
                                     progress=True)

        return self.load_onnx_session(cached_file)