Code example #1
import tarfile
import os.path as osp

# `cifar_web_address`, `download` and `_read_cifar` are module-level helpers
# defined elsewhere in this file.
def load_cifar(data_dir, nr_classes=10):
    assert nr_classes in (10, 100)

    data_file = 'cifar-{}-python.tar.gz'.format(nr_classes)
    origin = cifar_web_address + data_file
    dataset = osp.join(data_dir, data_file)
    if nr_classes == 10:
        folder_name = 'cifar-10-batches-py'
        filenames = ['data_batch_{}'.format(i) for i in range(1, 6)]
        filenames.append('test_batch')
    else:
        folder_name = 'cifar-100-python'
        filenames = ['train', 'test']

    if not osp.isdir(osp.join(data_dir, folder_name)):
        if not osp.isfile(dataset):
            download(origin, data_dir, data_file)
        with tarfile.open(dataset, 'r:gz') as f:
            f.extractall(data_dir)

    filenames = [osp.join(data_dir, folder_name, f) for f in filenames]

    # The last filename is always the test split; the rest form the training set.
    train_set = _read_cifar(filenames[:-1], nr_classes)
    test_set = _read_cifar([filenames[-1]], nr_classes)

    return train_set, test_set
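
A minimal usage sketch, assuming `_read_cifar` (defined elsewhere in this module) returns each split as an `(images, labels)` pair:

train_set, test_set = load_cifar('./data', nr_classes=10)
train_images, train_labels = train_set  # all five training batches combined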
Code example #2
File: mnist.py  Project: ExplorerFreda/Jacinle
import gzip
import pickle
import os.path as osp

# `download` is a module-level helper defined elsewhere in this file.
def load_mnist(
        data_dir,
        data_file='mnist.pkl.gz',
        origin='http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz'
):

    dataset = osp.join(data_dir, data_file)

    if (not osp.isfile(dataset)) and data_file == 'mnist.pkl.gz':
        download(origin, data_dir, data_file)

    # Load the dataset
    with gzip.open(dataset, 'rb') as f:
        try:
            # Python 3: the pickle was written by Python 2, so decode byte
            # strings as latin1.
            train_set, valid_set, test_set = pickle.load(f, encoding='latin1')
        except TypeError:
            # Python 2's pickle.load does not accept an `encoding` argument.
            train_set, valid_set, test_set = pickle.load(f)
    return train_set, valid_set, test_set
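
A minimal usage sketch; the classic LISA-lab `mnist.pkl.gz` stores each split as an `(images, labels)` pair with images flattened to 784-dimensional vectors:

train_set, valid_set, test_set = load_mnist('./data')
train_images, train_labels = train_set
print(train_images.shape)  # (50000, 784) for the standard pickle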
Code example #3
import os

import numpy as np

# `svhn_web_address` maps each subset name to a (url, filename, md5) tuple;
# `download` is a module-level helper defined elsewhere in this file.
def load_svhn(data_dir, extra=False):
    from scipy.io import loadmat

    # Relies on the insertion order of `svhn_web_address` (guaranteed for
    # dicts since Python 3.7): the first two entries are the train and test splits.
    all_set_keys = list(svhn_web_address.keys())
    if not extra:
        all_set_keys = all_set_keys[:2]

    all_sets = []

    for subset in all_set_keys:
        data_addr, data_file, data_hash = svhn_web_address[subset]

        dataset = os.path.join(data_dir, data_file)

        if not os.path.isfile(dataset):
            download(data_addr, data_dir, data_file, md5=data_hash)

        mat = loadmat(dataset)
        # SVHN .mat files store images as (32, 32, 3, N); move the sample
        # axis to the front to get the usual (N, H, W, C) layout.
        mat['X'] = np.transpose(mat['X'], [3, 0, 1, 2])
        all_sets.append((np.ascontiguousarray(mat['X']), mat['y']))

    return tuple(all_sets)
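
A minimal usage sketch, assuming `svhn_web_address` lists the train split first and the test split second:

(train_x, train_y), (test_x, test_y) = load_svhn('./data')
print(train_x.shape)  # (N, 32, 32, 3) after the transpose above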