# Imports reconstructed for this excerpt. `download`, `get_data_home`,
# `DataSource`, and `logger` are taken from nnabla's utilities, which is
# where these helpers originate; `load_pix2pix_dataset` and `load_npy` are
# assumed to be defined elsewhere in this module.
import errno
import os
import shutil
import tarfile
import time
import zipfile
from pathlib import Path
from typing import Tuple

import numpy as np
import scipy.misc
from tqdm import tqdm

from nnabla import logger
from nnabla.utils.data_source import DataSource
from nnabla.utils.data_source_loader import download, get_data_home


def download_tiny_imagenet():
    url = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"
    dir_data = os.path.join(get_data_home(), 'tiny-imagenet-200')
    if not os.path.isdir(dir_data):
        f = download(url)
        logger.info('Extracting {} ...'.format(f.name))
        z = zipfile.ZipFile(f)
        d = get_data_home()
        # Extract member by member so tqdm can report progress.
        for name in tqdm(z.namelist()):
            z.extract(name, d)
        z.close()
        f.close()
    return dir_data
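
# A minimal usage sketch, not part of the original module: fetch the archive
# once (subsequent calls return the cached directory) and list a few of the
# extracted class directories. The train/ layout is the standard
# tiny-imagenet-200 structure.
def _demo_download_tiny_imagenet():
    data_dir = download_tiny_imagenet()
    train_dir = os.path.join(data_dir, 'train')
    wnids = sorted(os.listdir(train_dir))[:5]
    print('tiny-imagenet at {}, first classes: {}'.format(data_dir, wnids))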
def prepare_pix2pix_dataset(dataset="edges2shoes", train=True):
    imgs_A, imgs_B, names = load_pix2pix_dataset(dataset=dataset, train=train)
    dname = "train" if train else "val"
    dpath_A = os.path.join(get_data_home(), "{}_A".format(dataset), dname)
    dpath_B = os.path.join(get_data_home(), "{}_B".format(dataset), dname)

    def save_img(dpath, names, imgs):
        # Recreate the output directory from scratch. The original called
        # os.rmdir, which fails on non-empty directories; shutil.rmtree
        # removes the tree recursively.
        if os.path.exists(dpath):
            shutil.rmtree(dpath)
        os.makedirs(dpath)
        for name, img in zip(names, imgs):
            fpath = os.path.join(dpath, name)
            # Convert CHW to HWC for image saving.
            img = img.transpose((1, 2, 0))
            logger.info("Save img to {}".format(fpath))
            # Note: scipy.misc.imsave was removed in SciPy 1.2;
            # imageio.imwrite is the usual drop-in replacement.
            scipy.misc.imsave(fpath, img)

    save_img(dpath_A, names, imgs_A)
    save_img(dpath_B, names, imgs_B)
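
# Usage sketch, not in the original module: materialize both splits of a
# pix2pix dataset on disk. Assumes load_pix2pix_dataset (defined elsewhere in
# this module) returns (imgs_A, imgs_B, names) with images in CHW layout.
def _demo_prepare_pix2pix():
    for train in (True, False):
        prepare_pix2pix_dataset(dataset="edges2shoes", train=train)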
def load_imdb(vocab_size: int) -> Tuple[np.ndarray, np.ndarray,
                                        np.ndarray, np.ndarray]:
    file_name = 'imdb.npz'
    url = f'https://s3.amazonaws.com/text-datasets/{file_name}'
    download(url, open_file=False)
    dataset_path = Path(get_data_home()) / file_name

    # Cap the vocabulary: any word index at or beyond vocab_size - 1 is
    # mapped to a single "unknown" token at index vocab_size - 1.
    unk_index = vocab_size - 1
    raw = load_npy(dataset_path)
    ret = dict()
    for k, v in raw.items():
        if 'x' in k:
            for i, sentence in enumerate(v):
                v[i] = [word if word < unk_index else unk_index
                        for word in sentence]
        ret[k] = v
    return ret['x_train'], ret['x_test'], ret['y_train'], ret['y_test']
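
# Hedged usage sketch: load IMDB with a 20k-word vocabulary and report basic
# statistics. Because of the cap above, every index in x_train/x_test is
# strictly less than vocab_size.
def _demo_load_imdb():
    x_train, x_test, y_train, y_test = load_imdb(vocab_size=20000)
    print('train reviews: {}, test reviews: {}'.format(
        len(x_train), len(x_test)))
    print('positive ratio (train): {:.3f}'.format(np.mean(y_train)))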
class Cifar100DataSource(DataSource):

    def __init__(self, train=True, shuffle=False, rng=None):
        super(Cifar100DataSource, self).__init__(shuffle=shuffle)

        # Busy-wait file lock: processes spawned by mpirun may race to
        # download the archive, so only the holder of the lock proceeds.
        lockfile = os.path.join(get_data_home(), "cifar100.lock")
        start_time = time.time()
        while True:
            try:
                fd = os.open(lockfile, os.O_CREAT | os.O_EXCL | os.O_RDWR)
                break
            except OSError as e:
                if e.errno != errno.EEXIST:
                    raise
            if (time.time() - start_time) >= 60 * 30:  # wait for 30 min
                raise Exception(
                    "Timeout occurred. If cifar100.lock remains in "
                    "$HOME/nnabla_data, delete it.")
            time.sleep(5)

        self._train = train
        data_uri = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
        logger.info('Getting labeled data from {}.'.format(data_uri))
        r = download(data_uri)  # file object returned
        with tarfile.open(fileobj=r, mode="r:gz") as fpin:
            if train:
                # Training data: CIFAR-100 ships a single "train" file.
                for member in fpin.getmembers():
                    if "train" not in member.name:
                        continue
                    fp = fpin.extractfile(member)
                    # The archive stores pickled dicts; NumPy >= 1.16.3
                    # requires allow_pickle=True to load them.
                    data = np.load(fp, encoding="bytes", allow_pickle=True)
                    images = data[b"data"]
                    labels = data[b"fine_labels"]
                self._size = 50000
                self._images = images.reshape(self._size, 3, 32, 32)
                self._labels = np.array(labels).reshape(-1, 1)
            else:
                # Validation data.
                for member in fpin.getmembers():
                    if "test" not in member.name:
                        continue
                    fp = fpin.extractfile(member)
                    data = np.load(fp, encoding="bytes", allow_pickle=True)
                    images = data[b"data"]
                    labels = data[b"fine_labels"]
                self._size = 10000
                self._images = images.reshape(self._size, 3, 32, 32)
                self._labels = np.array(labels).reshape(-1, 1)
        r.close()
        logger.info('Getting labeled data from {} done.'.format(data_uri))

        self._size = self._labels.size
        self._variables = ('x', 'y')
        if rng is None:
            rng = np.random.RandomState(313)
        self.rng = rng
        self.reset()

        # Unlock
        os.close(fd)
        os.unlink(lockfile)
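
# Usage sketch for the CIFAR-100 source. Assumes the rest of the DataSource
# protocol (_get_data etc.) is defined elsewhere in this class, and that
# next() returns one sample per name in self._variables, as in nnabla.
def _demo_cifar100_source():
    ds = Cifar100DataSource(train=True, shuffle=True)
    x, y = ds.next()
    print('image {}, fine label {}'.format(x.shape, int(y)))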
class Cifar10DataSource(DataSource):

    def __init__(self, train=True, shuffle=False, rng=None):
        super(Cifar10DataSource, self).__init__(shuffle=shuffle)

        # Busy-wait file lock, as above, to serialize the download across
        # processes spawned by mpirun.
        lockfile = os.path.join(get_data_home(), "cifar10.lock")
        start_time = time.time()
        while True:
            try:
                fd = os.open(lockfile, os.O_CREAT | os.O_EXCL | os.O_RDWR)
                break
            except OSError as e:
                if e.errno != errno.EEXIST:
                    raise
            if (time.time() - start_time) >= 60 * 30:  # wait for 30 min
                raise Exception(
                    "Timeout occurred. If cifar10.lock remains in "
                    "$HOME/nnabla_data, delete it.")
            time.sleep(5)

        self._train = train
        data_uri = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
        logger.info('Getting labeled data from {}.'.format(data_uri))
        r = download(data_uri)  # file object returned
        with tarfile.open(fileobj=r, mode="r:gz") as fpin:
            if train:
                # Training data: CIFAR-10 splits training data across five
                # data_batch files, so collect and concatenate them.
                images = []
                labels = []
                for member in fpin.getmembers():
                    if "data_batch" not in member.name:
                        continue
                    fp = fpin.extractfile(member)
                    data = np.load(fp, encoding="bytes", allow_pickle=True)
                    images.append(data[b"data"])
                    labels.append(data[b"labels"])
                self._size = 50000
                self._images = np.concatenate(
                    images).reshape(self._size, 3, 32, 32)
                self._labels = np.concatenate(labels).reshape(-1, 1)
            else:
                # Validation data.
                for member in fpin.getmembers():
                    if "test_batch" not in member.name:
                        continue
                    fp = fpin.extractfile(member)
                    data = np.load(fp, encoding="bytes", allow_pickle=True)
                    images = data[b"data"]
                    labels = data[b"labels"]
                self._size = 10000
                self._images = images.reshape(self._size, 3, 32, 32)
                self._labels = np.array(labels).reshape(-1, 1)
        r.close()
        logger.info('Getting labeled data from {} done.'.format(data_uri))

        self._size = self._labels.size
        self._variables = ('x', 'y')
        if rng is None:
            rng = np.random.RandomState(313)
        self.rng = rng
        self.reset()

        # Unlock
        os.close(fd)
        os.unlink(lockfile)
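
# Usage sketch: wrap the CIFAR-10 source in nnabla's data_iterator to draw
# mini-batches. data_iterator comes from nnabla.utils.data_iterator; the rest
# of the DataSource protocol is assumed to be defined elsewhere in this class.
def _demo_cifar10_iterator(batch_size=64):
    from nnabla.utils.data_iterator import data_iterator
    ds = Cifar10DataSource(train=True, shuffle=True)
    di = data_iterator(ds, batch_size, with_file_cache=False)
    x, y = di.next()
    print('batch images {}, labels {}'.format(x.shape, y.shape))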