import gzip
import os.path as osp
import pickle
import tarfile

import numpy as np

# NOTE: `download`, `_read_cifar`, `cifar_web_address`, and
# `svhn_web_address` are assumed to be defined elsewhere in this module.


def load_cifar(data_dir, nr_classes=10):
    assert nr_classes in (10, 100)
    data_file = 'cifar-{}-python.tar.gz'.format(nr_classes)
    origin = cifar_web_address + data_file
    dataset = osp.join(data_dir, data_file)

    # CIFAR-10 ships as five training batches plus one test batch;
    # CIFAR-100 ships as a single train file and a single test file.
    if nr_classes == 10:
        folder_name = 'cifar-10-batches-py'
        filenames = ['data_batch_{}'.format(i) for i in range(1, 6)]
        filenames.append('test_batch')
    else:
        folder_name = 'cifar-100-python'
        filenames = ['train', 'test']

    # Download and extract the archive only if it is not already unpacked.
    if not osp.isdir(osp.join(data_dir, folder_name)):
        if not osp.isfile(dataset):
            download(origin, data_dir, data_file)
        with tarfile.open(dataset, 'r:gz') as tar:
            tar.extractall(data_dir)

    filenames = [osp.join(data_dir, folder_name, f) for f in filenames]
    # The last filename is always the test split.
    train_set = _read_cifar(filenames[:-1], nr_classes)
    test_set = _read_cifar([filenames[-1]], nr_classes)
    return train_set, test_set
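
# Usage sketch (an illustration, not part of the original module): given a
# writable cache directory, e.g. a hypothetical '/tmp/data', the first call
# downloads and unpacks the archive; later calls read the cached files:
#
#     train_set, test_set = load_cifar('/tmp/data', nr_classes=10)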


def load_mnist(
        data_dir, data_file='mnist.pkl.gz',
        origin='http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz'):
    dataset = osp.join(data_dir, data_file)
    # Only auto-download the canonical file; a custom data_file is expected
    # to already exist on disk.
    if (not osp.isfile(dataset)) and data_file == 'mnist.pkl.gz':
        download(origin, data_dir, data_file)

    # Load the dataset.
    with gzip.open(dataset, 'rb') as f:
        try:
            # Python 3: the pickle was written by Python 2, so decode its
            # byte strings as latin-1.
            train_set, valid_set, test_set = pickle.load(f, encoding='latin1')
        except TypeError:
            # Python 2: pickle.load takes no `encoding` keyword.
            train_set, valid_set, test_set = pickle.load(f)
    return train_set, valid_set, test_set
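
# Usage sketch (hedged): the canonical mnist.pkl.gz stores each split as an
# (images, labels) pair, so unpacking typically looks like:
#
#     train_set, valid_set, test_set = load_mnist('/tmp/data')
#     train_images, train_labels = train_set  # assumes the (x, y) layout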


def load_svhn(data_dir, extra=False):
    # Import lazily so the module does not require scipy unless SVHN is used.
    from scipy.io import loadmat

    # `svhn_web_address` is expected to list its subsets in the order
    # (train, test, extra); the extra split is skipped unless requested.
    all_set_keys = list(svhn_web_address.keys())
    if not extra:
        all_set_keys = all_set_keys[:2]

    all_sets = []
    for subset in all_set_keys:
        data_addr, data_file, data_hash = svhn_web_address[subset]
        dataset = osp.join(data_dir, data_file)
        if not osp.isfile(dataset):
            download(data_addr, data_dir, data_file, md5=data_hash)
        mat = loadmat(dataset)
        # The .mat files store images as (H, W, C, N); move the sample axis
        # to the front to get (N, H, W, C).
        mat['X'] = np.transpose(mat['X'], [3, 0, 1, 2])
        all_sets.append((np.ascontiguousarray(mat['X']), mat['y']))
    return tuple(all_sets)
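

# A minimal smoke-test sketch, not part of the original module: it assumes
# the helpers referenced above (`download`, `_read_cifar`, and the
# `*_web_address` constants) exist, and that the hypothetical '/tmp/data'
# directory is writable. Running it downloads each dataset on first use.
if __name__ == '__main__':
    data_dir = '/tmp/data'  # hypothetical cache location

    train_set, valid_set, test_set = load_mnist(data_dir)
    cifar_train, cifar_test = load_cifar(data_dir, nr_classes=10)
    svhn_train, svhn_test = load_svhn(data_dir)

    # Print image-array shapes as a quick sanity check.
    print('MNIST train:', train_set[0].shape)
    print('SVHN train:', svhn_train[0].shape)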