예제 #1
0
    def download(self):
        """Fetch and unpack the VOC archive and, optionally, the augmented masks.

        The base archive is downloaded to ``self.root`` and extracted there
        unless ``self.voc_root`` already exists as a directory.

        When ``self.augmented`` is truthy and ``SegmentationClassAug`` is
        missing under ``self.voc_root``, the archive described by the
        module-level ``TRAINAUG_FILE`` mapping is downloaded from Google
        Drive, extracted into ``self.voc_root``, and the extracted
        ``trainaug.txt`` is moved into ``ImageSets/Segmentation``.
        """
        import tarfile

        if not self.voc_root.is_dir():
            download_url(self.url, self.root, self.filename, self.md5)

            archive_path = self.root / self.filename
            with tarfile.open(archive_path, "r") as archive:
                archive.extractall(path=self.root)
        else:
            print("VOC found. Skip download or extract")

        # Nothing more to do unless the augmented masks were requested.
        if not self.augmented:
            return

        aug_mask_dir = self.voc_root / 'SegmentationClassAug'
        if aug_mask_dir.is_dir():
            print("SBT found. Skip download or extract")
            return

        # The Google Drive file id is the query value of the "open?id=" link.
        url_match = re.match(r"https://drive.google.com/open\?id=(.*)",
                             TRAINAUG_FILE['url'])
        drive_id = url_match.group(1)
        aug_name = TRAINAUG_FILE['name']
        download_google_drive(drive_id, self.voc_root, aug_name,
                              TRAINAUG_FILE['md5'])

        with tarfile.open(self.voc_root / aug_name, "r") as archive:
            archive.extractall(path=self.voc_root)

        # Move the extracted split list next to the other segmentation splits.
        split_list = self.voc_root / 'trainaug.txt'
        target_dir = self.voc_root / 'ImageSets' / 'Segmentation'
        split_list.rename(target_dir / split_list.name)
예제 #2
0
    def download(self):
        """Download and extract the archive for the current split.

        Does nothing except print a notice when ``self._check_integrity()``
        reports success. Otherwise the entry of ``self.split_info`` selected
        by ``self.split`` supplies the ``url``, ``tgz_filename`` and
        ``tgz_md5`` passed to ``download_google_drive``; the downloaded
        gzip-compressed tar file is then extracted into ``self.root``.
        """
        import tarfile

        if not self._check_integrity():
            split_meta = self.split_info[self.split]
            download_google_drive(split_meta['url'],
                                  self.root,
                                  split_meta['tgz_filename'],
                                  split_meta['tgz_md5'])

            # Unpack the freshly downloaded archive next to it.
            archive_path = os.path.join(self.root, split_meta['tgz_filename'])
            with tarfile.open(archive_path, "r:gz") as archive:
                archive.extractall(path=self.root)
        else:
            print('Files already downloaded and verified')
예제 #3
0
    def download(self):
        """Download the VOC archive and the annotation file when missing.

        Returns immediately if ``self.voc_root`` is a directory and
        ``self.ann_file`` exists. Otherwise each missing piece is fetched
        independently:

        * the archive at ``self.url`` is downloaded into ``self.root`` and
          extracted there when ``self.voc_root`` is not a directory;
        * the annotation file is downloaded from Google Drive into
          ``self.voc_root`` under the name ``self.ann_file.name`` when
          ``self.ann_file`` does not exist.

        Raises:
            ValueError: if ``self.ann_file_url`` is not a Google Drive
                ``https://drive.google.com/open?id=<file id>`` link, so no
                file id can be extracted from it.
        """
        import tarfile

        if self.voc_root.is_dir() and self.ann_file.exists():
            print("Dataset found. Skip download or extract.")
            return

        if not self.voc_root.is_dir():
            download_url(self.url, self.root, self.filename, self.md5)
            with tarfile.open(self.root / self.filename, "r") as tar:
                tar.extractall(path=self.root)

        if not self.ann_file.exists():
            google_drive_match = re.match(
                r"https://drive.google.com/open\?id=(.*)", self.ann_file_url)
            # re.match returns None on a non-matching URL; fail with a clear
            # message instead of an AttributeError on ``None.group``.
            if google_drive_match is None:
                raise ValueError(
                    "Annotation URL is not a Google Drive 'open?id=' link: "
                    "{!r}".format(self.ann_file_url))
            file_id = google_drive_match.group(1)
            download_google_drive(file_id, self.voc_root, self.ann_file.name)
예제 #4
0
    def download(self):
        """Download and extract the dataset archive, then place the annotations.

        Returns immediately when ``self.img_dir`` is a directory and
        ``self.ann_file`` exists. Otherwise ``self.url`` is downloaded into
        ``self.root`` as ``self.filename`` (through Google Drive when the URL
        is an ``open?id=`` link, through ``download_url`` otherwise), the
        archive is extracted into ``self.root``, and
        ``self.root / self.ann_filename`` is renamed to ``self.ann_file``.
        """
        already_present = self.img_dir.is_dir() and self.ann_file.exists()
        if already_present:
            print("Dataset found. Skip download or extract")
            return

        # Google Drive links carry the file id as the "id" query value.
        drive_link = re.match(
            r"https://drive.google.com/open\?id=(.*)", self.url)
        if drive_link is None:
            download_url(self.url, self.root, self.filename, self.md5)
        else:
            download_google_drive(drive_link.group(1), self.root,
                                  self.filename, self.md5)

        with tarfile.open(self.root / self.filename, "r") as archive:
            archive.extractall(path=self.root)

        extracted_ann = self.root / self.ann_filename
        extracted_ann.rename(self.ann_file)
예제 #5
0
def load_state_dict_from_google_drive(file_id,
                                      filename,
                                      md5,
                                      model_dir=None,
                                      map_location=None):
    r"""Download a Torch serialized object from Google Drive and load it.

    The file is fetched with
    ``download_google_drive(file_id, model_dir, filename, md5)`` and then
    deserialized from ``model_dir / filename`` with ``torch.load``.

    When `model_dir` is ``None``, ``<torch home>/checkpoints`` is used, where
    ``<torch home>`` is the value returned by
    ``torch.hub._get_torch_home()``. The directory (including parents) is
    created if it does not exist.

    Args:
        file_id: Google Drive file identifier, forwarded to
            ``download_google_drive``
        filename (string): name under which the file is stored in `model_dir`
        md5: forwarded to ``download_google_drive`` (presumably the expected
            MD5 checksum of the file -- confirm against that helper)
        model_dir (string or Path, optional): directory in which to save the object
        map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)

    Returns:
        The object produced by ``torch.load`` for the downloaded file.

    """

    if model_dir is None:
        from torch.hub import _get_torch_home
        model_dir = Path(_get_torch_home()) / 'checkpoints'
    else:
        model_dir = Path(model_dir)

    model_dir.mkdir(parents=True, exist_ok=True)

    download_google_drive(file_id, model_dir, filename, md5)
    cached_file = model_dir / filename
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only use this with file ids from trusted sources.
    return torch.load(cached_file, map_location=map_location)