Code example #1
# Assumed context for this excerpt: os and zipfile are standard library, tqdm
# is the progress-bar package, and download/get_data_home/logger are the
# nnabla helpers used throughout these examples.
import os
import zipfile

from tqdm import tqdm

def download_tiny_imagenet():
    url = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"
    dir_data = os.path.join(get_data_home(), 'tiny-imagenet-200')
    if not os.path.isdir(dir_data):
        f = download(url)
        logger.info('Extracting {} ...'.format(f.name))
        # Extract member by member so tqdm can report progress.
        with zipfile.ZipFile(f) as z:
            data_home = get_data_home()
            for name in tqdm(z.namelist()):
                z.extract(name, data_home)
        f.close()
    return dir_data
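
A minimal usage sketch (hypothetical, not part of the original project): the helper is effectively idempotent, since it skips the download once the extracted directory exists.

data_dir = download_tiny_imagenet()  # downloads and extracts only on the first call
print(os.listdir(data_dir))          # e.g. ['train', 'val', 'test', 'wnids.txt', 'words.txt']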
Code example #2
File: datasets.py  Project: sony/nnabla-examples
# Assumed context: os and shutil are standard library; imageio replaces the
# removed scipy.misc.imsave; load_pix2pix_dataset, get_data_home, and logger
# are project helpers.
import os
import shutil

import imageio

def prepare_pix2pix_dataset(dataset="edges2shoes", train=True):
    imgs_A, imgs_B, names = load_pix2pix_dataset(dataset=dataset, train=train)
    dname = "train" if train else "val"
    dpath_A = os.path.join(get_data_home(), "{}_A".format(dataset), dname)
    dpath_B = os.path.join(get_data_home(), "{}_B".format(dataset), dname)

    def save_img(dpath, names, imgs):
        # os.rmdir() only removes empty directories; shutil.rmtree() clears
        # out any previous run before the directory is recreated.
        if os.path.exists(dpath):
            shutil.rmtree(dpath)
        os.makedirs(dpath)
        for name, img in zip(names, imgs):
            fpath = os.path.join(dpath, name)
            img = img.transpose((1, 2, 0))  # CHW -> HWC for the image writer
            logger.info("Save img to {}".format(fpath))
            # scipy.misc.imsave was removed in SciPy 1.2; imageio.imwrite is
            # the usual replacement (unlike imsave, it does not rescale floats).
            imageio.imwrite(fpath, img)
    save_img(dpath_A, names, imgs_A)
    save_img(dpath_B, names, imgs_B)
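
A hypothetical usage sketch: preparing both splits lays out edges2shoes_A/{train,val} and edges2shoes_B/{train,val} under the data home.

for split_is_train in (True, False):
    prepare_pix2pix_dataset(dataset="edges2shoes", train=split_is_train)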
Code example #3
# Assumed context: download, get_data_home, and load_npy are project helpers;
# the rest is standard library plus NumPy.
from pathlib import Path
from typing import Tuple

import numpy as np

def load_imdb(vocab_size: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    file_name = 'imdb.npz'
    url = f'https://s3.amazonaws.com/text-datasets/{file_name}'
    download(url, open_file=False)

    dataset_path = Path(get_data_home()) / file_name

    # Reserve the last in-vocabulary id as a shared "unknown word" index and
    # clamp every out-of-vocabulary id onto it.
    unk_index = vocab_size - 1
    raw = load_npy(dataset_path)
    ret = dict()
    for k, v in raw.items():
        if 'x' in k:  # only the text arrays ('x_train', 'x_test') hold word ids
            for i, sentence in enumerate(v):
                v[i] = [word if word < unk_index else unk_index for word in sentence]
        ret[k] = v
    return ret['x_train'], ret['x_test'], ret['y_train'], ret['y_test']
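
A hypothetical usage sketch with a 10,000-word vocabulary: any word id at or above 9999 comes back replaced by the shared unknown index 9999.

x_train, x_test, y_train, y_test = load_imdb(vocab_size=10000)
print(len(x_train), len(x_test))  # number of train/test reviews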
Code example #4
    # Excerpt from an nnabla DataSource subclass. Assumes errno, os, tarfile,
    # time, numpy (as np) and the download/get_data_home/logger helpers from
    # the surrounding project are already in scope.
    def __init__(self, train=True, shuffle=False, rng=None):
        super(Cifar100DataSource, self).__init__(shuffle=shuffle)

        # Lock
        lockfile = os.path.join(get_data_home(), "cifar100.lock")
        start_time = time.time()
        while True:  # busy-wait lock: processes spawned by mpirun race for the same download
            try:
                fd = os.open(lockfile, os.O_CREAT | os.O_EXCL | os.O_RDWR)
                break
            except OSError as e:
                if e.errno != errno.EEXIST:
                    raise
                if (time.time() - start_time) >= 60 * 30:  # give up after 30 min
                    raise Exception(
                        "Timeout occurred. If a stale cifar100.lock remains in $HOME/nnabla_data, delete it."
                    )

            time.sleep(5)

        self._train = train
        data_uri = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
        logger.info('Getting labeled data from {}.'.format(data_uri))
        r = download(data_uri)  # file object returned
        with tarfile.open(fileobj=r, mode="r:gz") as fpin:
            # Training data
            if train:
                # CIFAR-100 ships a single pickled "train" file, so direct
                # assignment (rather than list accumulation) is enough here.
                for member in fpin.getmembers():
                    if "train" not in member.name:
                        continue
                    fp = fpin.extractfile(member)
                    # The archive members are pickled dicts; recent NumPy
                    # defaults allow_pickle to False, so pass it explicitly.
                    data = np.load(fp, encoding="bytes", allow_pickle=True)
                    images = data[b"data"]
                    labels = data[b"fine_labels"]
                self._size = 50000
                self._images = images.reshape(self._size, 3, 32, 32)
                self._labels = np.array(labels).reshape(-1, 1)
            # Validation data
            else:
                for member in fpin.getmembers():
                    if "test" not in member.name:
                        continue
                    fp = fpin.extractfile(member)
                    data = np.load(fp, encoding="bytes", allow_pickle=True)
                    images = data[b"data"]
                    labels = data[b"fine_labels"]
                self._size = 10000
                self._images = images.reshape(self._size, 3, 32, 32)
                self._labels = np.array(labels).reshape(-1, 1)
        r.close()
        logger.info('Getting labeled data from {} done.'.format(data_uri))

        self._size = self._labels.size
        self._variables = ('x', 'y')
        if rng is None:
            rng = np.random.RandomState(313)
        self.rng = rng
        self.reset()

        # Unlock
        os.close(fd)
        os.unlink(lockfile)
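
A hypothetical usage sketch: passing an explicit RNG makes the shuffling reproducible instead of falling back to the fixed default seed 313.

rng = np.random.RandomState(42)
cifar100 = Cifar100DataSource(train=True, shuffle=True, rng=rng)
print(cifar100._size)  # 50000, set in __init__ above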
Code example #5
File: cifar10_data.py  Project: zwsong/nnabla
    # As in the previous example, an excerpt from an nnabla DataSource
    # subclass; the same imports and helpers are assumed to be in scope.
    def __init__(self, train=True, shuffle=False, rng=None):
        super(Cifar10DataSource, self).__init__(shuffle=shuffle)

        # Lock
        lockfile = os.path.join(get_data_home(), "cifar10.lock")
        start_time = time.time()
        while True:  # busy-wait lock: processes spawned by mpirun race for the same download
            try:
                fd = os.open(lockfile, os.O_CREAT | os.O_EXCL | os.O_RDWR)
                break
            except OSError as e:
                if e.errno != errno.EEXIST:
                    raise
                if (time.time() - start_time) >= 60 * 30:  # give up after 30 min
                    raise Exception(
                        "Timeout occurred. If a stale cifar10.lock remains in $HOME/nnabla_data, delete it.")

            time.sleep(5)

        self._train = train
        data_uri = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
        logger.info('Getting labeled data from {}.'.format(data_uri))
        r = download(data_uri)  # file object returned
        with tarfile.open(fileobj=r, mode="r:gz") as fpin:
            # Training data
            if train:
                images = []
                labels = []
                for member in fpin.getmembers():
                    if "data_batch" not in member.name:
                        continue
                    fp = fpin.extractfile(member)
                    # Pickled dict members again: allow_pickle must be
                    # explicit with recent NumPy.
                    data = np.load(fp, encoding="bytes", allow_pickle=True)
                    images.append(data[b"data"])
                    labels.append(data[b"labels"])
                self._size = 50000
                self._images = np.concatenate(
                    images).reshape(self._size, 3, 32, 32)
                self._labels = np.concatenate(labels).reshape(-1, 1)
            # Validation data
            else:
                for member in fpin.getmembers():
                    if "test_batch" not in member.name:
                        continue
                    fp = fpin.extractfile(member)
                    data = np.load(fp, encoding="bytes", allow_pickle=True)
                    images = data[b"data"]
                    labels = data[b"labels"]
                self._size = 10000
                self._images = images.reshape(self._size, 3, 32, 32)
                self._labels = np.array(labels).reshape(-1, 1)
        r.close()
        logger.info('Getting labeled data from {} done.'.format(data_uri))

        self._size = self._labels.size
        self._variables = ('x', 'y')
        if rng is None:
            rng = np.random.RandomState(313)
        self.rng = rng
        self.reset()

        # Unlock
        os.close(fd)
        os.unlink(lockfile)
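
The lock idiom shared by the last two examples is worth isolating. A standard-library-only sketch (hypothetical function names, not part of either project): os.O_CREAT | os.O_EXCL makes creation of the lock file atomic, so whichever mpirun-spawned worker wins performs the download while the others poll until the file disappears.

import errno
import os
import time

def acquire_lock(path, timeout_s=30 * 60, poll_s=5):
    """Busy-wait until the lock file can be created atomically."""
    start = time.time()
    while True:
        try:
            # O_CREAT | O_EXCL fails if the file already exists:
            # an atomic test-and-create across processes.
            return os.open(path, os.O_CREAT | os.O_EXCL | os.O_RDWR)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise
            if time.time() - start >= timeout_s:
                raise TimeoutError("stale lock? consider deleting {}".format(path))
            time.sleep(poll_s)

def release_lock(fd, path):
    os.close(fd)
    os.unlink(path)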