예제 #1
0
def get_dSprite(data_dir='', download=True):
    """Return the dSprites ``.npz`` archive, downloading it first if absent.

    Args:
        data_dir (str): Directory in which the archive is looked up and, if
            needed, saved. ``''`` means the current working directory.
        download (bool): Fetch the archive from GitHub when it is not
            already present in ``data_dir``.

    Returns:
        Whatever ``np.load`` returns for the archive (an ``NpzFile`` for a
        ``.npz`` file). It keeps the file open; the caller should close it.

    Raises:
        urllib.error.URLError, IOError: the download failed (for https URLs,
            after also retrying over http). Any partially written archive is
            removed before the error propagates.
        FileNotFoundError: the archive is absent and ``download`` is False.
    """
    dSprite_filename = 'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz'
    file_path = os.path.join(data_dir, dSprite_filename)

    url = 'https://github.com/deepmind/dsprites-dataset/blob/master/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz?raw=true'

    if not os.path.isfile(file_path) and download:
        if data_dir:
            # urlretrieve does not create a missing target directory.
            os.makedirs(data_dir, exist_ok=True)
        try:
            try:
                print('Downloading ' + url + ' to ' + file_path)
                urllib.request.urlretrieve(url,
                                           file_path,
                                           reporthook=gen_bar_updater())
            except (urllib.error.URLError, IOError):
                if not url.startswith('https'):
                    raise
                url = url.replace('https:', 'http:')
                print('Failed download. Trying https -> http instead.'
                      ' Downloading ' + url + ' to ' + file_path)
                urllib.request.urlretrieve(url,
                                           file_path,
                                           reporthook=gen_bar_updater())
        except BaseException:
            # Remove any truncated archive. Otherwise the isfile() check above
            # would treat it as complete on every later call and np.load
            # would be handed a broken file.
            if os.path.isfile(file_path):
                os.remove(file_path)
            raise

    print('Loading {}'.format(file_path))
    # allow_pickle lets np.load unpickle object arrays (presumably needed for
    # something stored in this archive — TODO confirm). Unpickling can run
    # arbitrary code, so only load a trusted file. numpy uses `encoding`
    # when unpickling objects written by Python 2.
    dataset_zip = np.load(file_path, encoding='latin1', allow_pickle=True)
    return dataset_zip
예제 #2
0
def get_texture_images(data_dir='', download=True):
    """Download the DTD texture archive and save a few cropped patches.

    The archive is downloaded into ``data_dir`` and extracted there. Three
    hand-picked images are then center-cropped, resized and written as
    ``crop_<name>`` into ``<data_dir>/textures``. Finally the archive and
    the extracted ``dtd`` tree are deleted.

    Nothing is done if ``<data_dir>/textures`` already exists or ``download``
    is False.

    Args:
        data_dir (str): Directory used for the download, extraction and
            output. ``''`` means the current working directory.
        download (bool): Whether to fetch and build the patches when the
            ``textures`` directory is missing.
    """
    url = 'https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz'
    tar_filename = url.split('/')[-1]
    file_path = os.path.join(data_dir, tar_filename)
    textures_dir = os.path.join(data_dir, 'textures')
    # Root of the extracted archive (the image paths below assume `dtd/`).
    dtd_dir = os.path.join(data_dir, 'dtd')

    # Check the same directory that is populated below, so every path in this
    # function is resolved against data_dir rather than the cwd.
    if os.path.isdir(textures_dir) or not download:
        return

    if data_dir:
        # urlretrieve does not create a missing target directory.
        os.makedirs(data_dir, exist_ok=True)

    try:
        print('Downloading ' + url + ' to ' + file_path)
        urllib.request.urlretrieve(url,
                                   file_path,
                                   reporthook=gen_bar_updater())
    except (urllib.error.URLError, IOError):
        if not url.startswith('https'):
            raise
        url = url.replace('https:', 'http:')
        print('Failed download. Trying https -> http instead.'
              ' Downloading ' + url + ' to ' + file_path)
        urllib.request.urlretrieve(url,
                                   file_path,
                                   reporthook=gen_bar_updater())

    print('Extracting images')
    # Where supported (Python 3.12+ and security backports), the 'data' filter
    # rejects archive members that would land outside data_dir.
    extract_kwargs = {'filter': 'data'} if hasattr(tarfile, 'data_filter') else {}
    with tarfile.open(file_path) as tar:
        tar.extractall(path=data_dir or '.', **extract_kwargs)

    print('Extracting texture patches')
    os.mkdir(textures_dir)

    # (image file name, crop size handed to center_crop)
    patches = [
        ('banded_0022.jpg', 70),
        ('grid_0079.jpg', 200),
        ('zigzagged_0024.jpg', 160),
    ]

    for img_filename, crop in tqdm(patches):
        # DTD stores each image under images/<category>/, where the category
        # is the file-name prefix before the first underscore.
        category_folder = img_filename.split('_')[0]
        path = os.path.join(dtd_dir, 'images', category_folder, img_filename)
        texture = skimage.io.imread(path)
        texture = center_crop(texture, crop)
        texture = resize(texture, 28)
        # Assumes resize() returns floats in [0, 1] — TODO confirm.
        texture = (texture * 255).astype(np.uint8)

        save_path = os.path.join(textures_dir, 'crop_' + img_filename)
        skimage.io.imsave(save_path, texture)

    print('Cleaning up unnecessary files')
    os.remove(file_path)
    shutil.rmtree(dtd_dir)
예제 #3
0
def download_url(url, root, filename=None, md5=None):
    """Download a file from a url and place it in root.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the basename of the URL
        md5 (str, optional): MD5 checksum of the download. If None, do not check

    Raises:
        urllib.error.URLError / IOError: the download failed. For https URLs
            this is raised only after an http retry also failed.
        RuntimeError: the file is missing or fails ``check_integrity``
            after downloading.
    """
    from six.moves import urllib

    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    # check if file is already present locally
    if check_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
        return

    try:
        if 'dropbox' in url:
            # Dropbox links are fetched with requests, sending a wget
            # user-agent header.
            import contextlib
            import requests
            headers = {'user-agent': 'Wget/1.16 (linux-gnu)'}
            with contextlib.closing(
                    requests.get(url, stream=True, headers=headers)) as r:
                # Fail on HTTP errors instead of saving the error page as the
                # dataset. requests' exceptions derive from IOError, so this
                # reaches the handler below like any other download failure.
                r.raise_for_status()
                with open(fpath, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=1024):
                        if chunk:
                            f.write(chunk)
        elif 'Manual' in url:
            # NOTE(review): this is raised inside the try block, so for https
            # URLs the handler below still retries the download over http.
            # Confirm that is intended rather than a hard stop.
            raise urllib.error.URLError(url)
        else:
            print('Downloading ' + url + ' to ' + fpath)
            urllib.request.urlretrieve(url,
                                       fpath,
                                       reporthook=gen_bar_updater())
    except (urllib.error.URLError, IOError):
        if not url.startswith('https'):
            raise
        url = url.replace('https:', 'http:')
        print('Failed download. Trying https -> http instead.'
              ' Downloading ' + url + ' to ' + fpath)
        urllib.request.urlretrieve(url,
                                   fpath,
                                   reporthook=gen_bar_updater())
    # check integrity of downloaded file
    if not check_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")
예제 #4
0
def download(url, path):
    """Download ``url`` into directory ``path`` unless it is already there.

    The local file name is the last ``/``-separated component of ``url``.
    The data is first written to ``<name>.part`` and renamed on success.
    An interrupted download therefore never leaves a file that a later call
    would report as "already downloaded".

    Args:
        url (str): Address to fetch.
        path (str): Target directory; created if missing. ``''`` means the
            current working directory.

    Returns:
        str: Path of the (possibly pre-existing) downloaded file.
    """
    if path:
        os.makedirs(path, exist_ok=True)
    file_name = os.path.join(path, url.split("/")[-1])
    if os.path.exists(file_name):
        print(f"Dataset already downloaded at {file_name}.")
        return file_name

    partial_name = file_name + ".part"
    try:
        urllib.request.urlretrieve(url, partial_name, reporthook=gen_bar_updater())
        # os.replace overwrites atomically on the same filesystem.
        os.replace(partial_name, file_name)
    finally:
        # Only present here if the download or rename failed.
        if os.path.exists(partial_name):
            os.remove(partial_name)
    return file_name
예제 #5
0
파일: utils.py 프로젝트: jwitos/torchio
def download_url(
    url: str,
    root: TypePath,
    filename: Optional[TypePath] = None,
    md5: Optional[str] = None,
) -> None:
    """Download a file from a url and place it in root.

    Args:
        url: URL to download file from
        root: Directory to place downloaded file in
        filename: Name to save the file under.
            If ``None``, use the basename of the URL
        md5: MD5 checksum of the download. If None, do not check

    Raises:
        urllib.error.URLError, OSError: the download failed. For https URLs
            this is raised only after an http retry also failed.
        RuntimeError: the file is missing or fails ``check_integrity``
            after downloading.
    """
    # Import the submodules explicitly: a bare ``import urllib`` does not bind
    # ``urllib.request`` / ``urllib.error``. They would only be reachable if
    # some other module happened to import them first.
    import urllib.error
    import urllib.request
    from torchvision.datasets.utils import check_integrity, gen_bar_updater

    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)
    os.makedirs(root, exist_ok=True)
    # check if file is already present locally
    if check_integrity(fpath, md5):
        return
    try:
        print('Downloading ' + url + ' to ' + fpath)  # noqa: T001
        urllib.request.urlretrieve(url,
                                   fpath,
                                   reporthook=gen_bar_updater())
    except (urllib.error.URLError, OSError):
        if not url.startswith('https'):
            raise
        url = url.replace('https:', 'http:')
        message = ('Failed download. Trying https -> http instead.'
                   ' Downloading ' + url + ' to ' + fpath)
        print(message)  # noqa: T001
        urllib.request.urlretrieve(url,
                                   fpath,
                                   reporthook=gen_bar_updater())
    # check integrity of downloaded file
    if not check_integrity(fpath, md5):
        raise RuntimeError('File not found or corrupted.')