Example #1
File: svhn.py  Project: ylfzr/ADGM
def _download(extra=False, normalize=True):
    """
    Download the SVHN dataset if it is not present.
    :return: The train, test and validation set.
    """

    def norm(x):
        # Reorder axes so the colour channel sits at axis 1, then flatten the spatial dims.
        x = x.swapaxes(2, 3).swapaxes(1, 2)
        x = x.reshape((-1, 3, 32 * 32))
        # Scale each RGB channel by its standard deviation (no mean centring).
        std = x.std(axis=(-1, 0))
        x[:, 0] /= std[0]
        x[:, 1] /= std[1]
        x[:, 2] /= std[2]
        x = x.reshape((-1, 3 * 32 * 32))
        return x

    train_x, train_t, test_x, test_t = load_svhn(os.path.join(env_paths.get_data_path("svhn"), ""),
                                                 normalize=False,
                                                 dequantify=True,
                                                 extra=extra)

    if normalize:
        train_x = norm(train_x)
        test_x = norm(test_x)

    train_t = np.array(train_t, dtype='float32').reshape(-1)
    test_t = np.array(test_t, dtype='float32').reshape(-1)

    # Dummy validation set. NOTE: still in training set.
    idx = np.random.randint(0, train_x.shape[0] - 1, 5000)
    valid_x = train_x[idx, :]
    valid_t = train_t[idx]

    return (train_x, train_t), (test_x, test_t), (valid_x, valid_t)
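The norm helper above divides each colour channel by its standard deviation but does not subtract the mean. A minimal, vectorised sketch of the same per-channel scaling, assuming the data has already been brought into the (N, 3, 32, 32) layout that the swapaxes/reshape calls produce:

import numpy as np

def norm_vectorised(x):
    # x: float array of shape (N, 3, 32, 32); scale each colour channel by its std.
    flat = x.reshape(-1, 3, 32 * 32)
    std = flat.std(axis=(0, 2), keepdims=True)  # one std per channel, as in norm above
    flat = flat / std
    return flat.reshape(-1, 3 * 32 * 32)

x = np.random.rand(8, 3, 32, 32).astype('float32')
print(norm_vectorised(x).shape)  # (8, 3072)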
Example #2
def _download():
    """
    Download the MNIST dataset if it is not present.
    :return: The train, test and validation set.
    """
    dataset = 'mnist.pkl.gz'
    data_dir, data_file = os.path.split(dataset)
    if data_dir == "" and not os.path.isfile(dataset):
        # Check if dataset is in the data directory.
        new_path = os.path.join(
            env_paths.get_data_path(),
            dataset
        )
        if os.path.isfile(new_path) or data_file == 'mnist.pkl.gz':
            dataset = new_path

    if (not os.path.isfile(dataset)) and data_file == 'mnist.pkl.gz':
        import urllib
        origin = (
            'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz'
        )
        print 'Downloading data from %s' % origin
        urllib.urlretrieve(origin, dataset)

    f = gzip.open(dataset, 'rb')
    train_set, valid_set, test_set = cPickle.load(f)
    f.close()
    return train_set, test_set, valid_set
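This example is Python 2 (print statement, urllib.urlretrieve, cPickle). A rough Python 3 port of the same download-and-unpickle step, under the usual assumption that the Python 2 pickle inside mnist.pkl.gz needs latin1 decoding:

import gzip
import os
import pickle
from urllib.request import urlretrieve

def download_mnist_py3(path='mnist.pkl.gz'):
    # Fetch the pickled MNIST file if it is not already on disk.
    origin = 'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz'
    if not os.path.isfile(path):
        print('Downloading data from %s' % origin)
        urlretrieve(origin, path)
    # The file was pickled under Python 2, so latin1 decoding is needed in Python 3.
    with gzip.open(path, 'rb') as f:
        train_set, valid_set, test_set = pickle.load(f, encoding='latin1')
    return train_set, test_set, valid_set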
Example #3
def _download():
    """
    Download the MNIST dataset if it is not present.
    :return: The train, test and validation set.
    """
    data = load_mnist_realval(os.path.join(env_paths.get_data_path("mnist"), "mnist.pkl.gz"))
    train_x, train_t, valid_x, valid_t, test_x, test_t = data
    return (train_x, train_t), (test_x, test_t), (valid_x, valid_t)
Example #4
def _download():
    """
    Download the MNIST dataset if it is not present.
    :return: The train, test and validation set.
    """
    data = load_mnist_realval(os.path.join(env_paths.get_data_path("mnist"), "mnist.pkl.gz"))
    train_x, train_t, valid_x, valid_t, test_x, test_t = data
    return (train_x, train_t), (test_x, test_t), (valid_x, valid_t)
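A short usage sketch for this wrapper; it assumes the project-specific env_paths and load_mnist_realval helpers are importable, and simply unpacks the nested tuples in the order returned above:

(train_x, train_t), (test_x, test_t), (valid_x, valid_t) = _download()
print(train_x.shape, valid_x.shape, test_x.shape)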
Example #5
    def load_data(data_file):
        # set temp environ data path for pylearn2.
        os.environ['PYLEARN2_DATA_PATH'] = env_paths.get_data_path("norb")

        data_dir = os.path.join(os.environ['PYLEARN2_DATA_PATH'], 'norb_small', 'original')
        if not os.path.exists(data_dir):
            os.makedirs(data_dir)
        dataset = os.path.join(data_dir, data_file)

        if (not os.path.isfile(dataset)):
            import urllib
            origin = (
                os.path.join('http://www.cs.nyu.edu/~ylclab/data/norb-v1.0-small/', data_file)
            )
            logger.info('Downloading data from %s', origin)

            urllib.urlretrieve(origin, dataset)
        return dataset
Example #6
    def load_data(data_file):
        # set temp environ data path for pylearn2.
        os.environ['PYLEARN2_DATA_PATH'] = env_paths.get_data_path("norb")

        data_dir = os.path.join(os.environ['PYLEARN2_DATA_PATH'], 'norb_small', 'original')
        if not os.path.exists(data_dir):
            os.makedirs(data_dir)
        dataset = os.path.join(data_dir, data_file)

        if (not os.path.isfile(dataset)):
            import urllib
            origin = (
                os.path.join('http://www.cs.nyu.edu/~ylclab/data/norb-v1.0-small/', data_file)
            )
            print 'Downloading data from %s' % origin

            urllib.urlretrieve(origin, dataset)
        return dataset
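Both NORB examples build the download URL with os.path.join, which uses the platform path separator and can produce backslashes on Windows. A small sketch of a more portable join (urljoin lives in urlparse on Python 2 and urllib.parse on Python 3; the file name below is just an illustrative placeholder):

try:                                    # Python 2
    from urlparse import urljoin
except ImportError:                     # Python 3
    from urllib.parse import urljoin

base = 'http://www.cs.nyu.edu/~ylclab/data/norb-v1.0-small/'
data_file = 'smallnorb-training-dat.mat.gz'  # placeholder file name
origin = urljoin(base, data_file)
print(origin)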
Example #7
def _load(max_parse_size):
    """
    Download the AG News dataset if it is not present.
    :return: The train, test and validation set.
    """

    source_path = env_paths.get_data_path("ag_news")
    package_path = os.path.join(source_path, "ag_news_csv.tar.gz")
    extracted_path = os.path.join(source_path, "ag_news_csv")
    processed_path = os.path.join(source_path, "dataset_%i.npz" % max_parse_size)

    if os.path.isfile(processed_path):
        train_set, test_set = np.load(open(processed_path, "rb"))
        return train_set, test_set, None

    origin = (
        "https://drive.google.com/uc?id=0Bz8a_Dbh9QhbUDNpeUdjb0wxRms&export=download"
    )
    print 'Downloading data from %s' % origin
    urllib.urlretrieve(origin, package_path)

    def csv_files(members):
        for tarinfo in members:
            if os.path.splitext(tarinfo.name)[1] == ".csv":
                yield tarinfo

    print 'Extracting data to %s' % extracted_path
    tar = tarfile.open(package_path)
    tar.extractall(path=source_path, members=csv_files(tar))
    tar.close()

    train = pd.read_csv(os.path.join(extracted_path, "train.csv"), header=None).values
    test = pd.read_csv(os.path.join(extracted_path, "test.csv"), header=None).values

    # remove downloaded and extracted files.
    os.remove(package_path)
    shutil.rmtree(extracted_path)

    def transform(dat):
        # Append the last text column to column 1 and drop it; column 0 holds the class label.
        dat[:, 1] += " " + dat[:, -1]
        dat = np.delete(dat, -1, axis=1)
        x = dat[:, 1]
        t = np.array(dat[:, 0], dtype='float32')
        return x, t

    print 'Transforming data'
    train_set = transform(train)
    test_set = transform(test)

    def _parse(xy, max_size):
        # Encode each string as a fixed-length sequence of alphabet indices; characters
        # outside the alphabet are left as 0, and labels are shifted to start at 0.
        alphabet = np.array(list("abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\|_@#$%^&*~`+-=<>()[]{}"))
        x, y = xy
        new_x = np.zeros((x.shape[0], max_size))
        for i in range(x.shape[0]):
            line = list(x[i].lower())
            for j in range(len(line[:max_size])):
                char = line[j]
                if char in alphabet:
                    idx = np.where(alphabet == char)[0][0]
                    new_x[i, j] = idx
        y -= 1.
        return new_x, y

    print 'Parsing data'
    train_set = _parse(train_set, max_parse_size)
    test_set = _parse(test_set, max_parse_size)

    print 'Dump data'
    np.save(open(processed_path, "wb"), (train_set, test_set))

    return train_set, test_set, None
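The processed dataset is written with np.save as a pickled tuple of arrays (despite the .npz suffix, np.save writes plain .npy data into the already-open file), so reading it back with recent NumPy requires allow_pickle=True; np.load has defaulted to allow_pickle=False since NumPy 1.16.3. A minimal sketch of the round trip under that assumption:

import numpy as np

train_set = (np.zeros((4, 10)), np.zeros(4))  # stand-ins for (x, t)
test_set = (np.ones((2, 10)), np.ones(2))

with open('dataset_10.npz', 'wb') as f:       # stand-in for processed_path
    np.save(f, np.array((train_set, test_set), dtype=object))

with open('dataset_10.npz', 'rb') as f:
    train_set, test_set = np.load(f, allow_pickle=True)
print(train_set[0].shape, test_set[0].shape)  # (4, 10) (2, 10)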