예제 #1
0
파일: svhn.py 프로젝트: jglombitza/dlipr
def load_data(extra=False):
    """Load the SVHN dataset, optionally including the "extra" split.

    Args:
        extra (bool, optional): if True, also load the extra training split
            into ``extra_images`` / ``extra_labels``

    Returns:
        Dataset: SVHN data with train/test (and optionally extra) images and
        flat label arrays
    """
    data = Dataset()
    data.classes = np.arange(10)

    path_template = get_datapath('SVHN/%s_32x32.mat')

    def read_split(split):
        """Read one ``<split>_32x32.mat`` file; return (images, 1-D labels)."""
        mat = sio.loadmat(path_template % split)
        # move axis 3 to the front (presumably the per-sample axis)
        images = mat['X'].transpose(3, 0, 1, 2)
        # map label 10 to 0 and flatten the label array
        labels = (mat['y'] % 10).reshape(-1)
        return images, labels

    data.train_images, data.train_labels = read_split('train')
    data.test_images, data.test_labels = read_split('test')

    if extra:
        data.extra_images, data.extra_labels = read_split('extra')

    return data
예제 #2
0
def load_data():
    """Load the dataset of images of simulated GISAXS measurements
    (Grazing Incidence Small-Angle X-ray Scattering).
    The dataset contains the speckled (noisy) images along with the underlying
    unspeckled images for training a denoising autoencoder.

    Note: calling this function replaces ``plot_examples`` on the ``Dataset``
    class itself (not only on the returned instance), so every ``Dataset``
    object uses the speckled/unspeckled plot afterwards.

    Returns:
        Dataset: Speckled and unspeckled images (20000 train, 5500 test)
    """
    data = Dataset()

    # monkey-patch the plot_examples function
    def monkeypatch_method(cls):
        """Return a decorator that attaches the decorated function to ``cls``."""
        def decorator(func):
            setattr(cls, func.__name__, func)
            return func

        return decorator

    @monkeypatch_method(Dataset)
    def plot_examples(self, num_examples=10, fname=None):
        """Plot the first examples of speckled and unspeckled images.

        Args:
            num_examples (int, optional): number of examples to plot per row
                (one row of speckled, one row of unspeckled images)
            fname (str, optional): filename for saving the plot
        """
        # squeeze=False keeps ``axes`` 2-D even when num_examples == 1, so the
        # ``axes[i, j]`` indexing below always works.
        fig, axes = plt.subplots(
            2, num_examples, figsize=(num_examples, 2), squeeze=False)
        for i, X in enumerate((self.X_train, self.Y_train)):
            for j in range(num_examples):
                ax = axes[i, j]
                ax.imshow(X[j])
                ax.set_xticks([])
                ax.set_yticks([])
        axes[0, 0].set_ylabel('speckled')
        axes[1, 0].set_ylabel('unspeckled')
        maybe_savefig(fig, fname)

    def to_images(X):
        """Swap the first two axes, then flatten into a stack of 64x64 images."""
        return np.swapaxes(X, 0, 1).reshape((-1, 64, 64))

    fname = get_datapath('AutoEncoder/data.h5')
    # Open read-only (h5py versions before 3.0 defaulted to append mode) and
    # close the file once the arrays are in memory.
    with h5py.File(fname, 'r') as h5file:
        fin = h5file['data']
        # [()] reads each HDF5 dataset fully into a NumPy array before the
        # file is closed.
        speckle = to_images(fin['speckle_images'][()])
        normal = to_images(fin['normal_images'][()])

    data.X_train, data.X_test = np.split(speckle, [20000])
    data.Y_train, data.Y_test = np.split(normal, [20000])
    return data
예제 #3
0
파일: cifar.py 프로젝트: jglombitza/dlipr
def load_cifar100(label_key='fine_labels'):
    """Load CIFAR-100 data set using the 'fine_labels' or 'coarse_labels'.

    Args:
        label_key (str, optional): which label set to load; either
            ``'fine_labels'`` (default) or ``'coarse_labels'``. Any other
            value is passed to ``load_pickle`` and the coarse class names
            are used.

    Returns:
        Dataset: CIFAR-100 data
    """
    print('Loading CIFAR-100 dataset with %s' % label_key)

    fname = get_datapath('CIFAR/cifar-100-data/%s.pickle')
    X_train, y_train = load_pickle(fname % 'train', label_key)
    X_test, y_test = load_pickle(fname % 'test', label_key)

    data = Dataset()

    if label_key == 'fine_labels':
        data.classes = cifar100_fine_labels
        # Stored integers refer to alphabetically sorted fine labels.
        # sorting[k] is the position in data.classes of the k-th label in
        # sorted order, so indexing ``sorting`` with the stored labels
        # re-maps them onto data.classes.
        sorting = np.argsort(data.classes)
        # Vectorized re-mapping via fancy indexing (instead of a Python loop);
        # an explicit intp index dtype keeps empty label lists valid.
        y_train = sorting[np.asarray(y_train, dtype=np.intp)].astype('uint8')
        y_test = sorting[np.asarray(y_test, dtype=np.intp)].astype('uint8')
    else:
        data.classes = cifar100_coarse_labels

    data.train_images = X_train
    data.train_labels = y_train
    data.test_images = X_test
    data.test_labels = y_test
    return data
예제 #4
0
파일: cifar.py 프로젝트: jglombitza/dlipr
def load_cifar10():
    """Load the CIFAR-10 data set.

    Pickled batches 1-5 (10000 samples each) are copied into preallocated
    uint8 training arrays; batch 6 is used as the test split.

    Returns:
        Dataset: CIFAR-10 data
    """
    print('Loading CIFAR-10 dataset')

    batch_path = get_datapath('CIFAR/cifar-10-data/batch_%i.pickle')
    per_batch = 10000
    num_train_batches = 5
    n_train = num_train_batches * per_batch

    train_X = np.empty((n_train, 32, 32, 3), dtype='uint8')
    train_y = np.empty(n_train, dtype='uint8')

    for batch_no in range(1, num_train_batches + 1):
        images, labels = load_pickle(batch_path % batch_no)
        rows = slice((batch_no - 1) * per_batch, batch_no * per_batch)
        train_X[rows] = images
        train_y[rows] = labels

    # the sixth pickle holds the test split
    test_X, test_y = load_pickle(batch_path % 6)

    data = Dataset()
    data.classes = cifar10_labels
    data.train_images = train_X
    data.train_labels = train_y
    data.test_images = test_X
    data.test_labels = test_y
    return data
예제 #5
0
def load_data():
    """Load a small dataset of flower photos (daisies, dandelions, roses, sunflowers, tulips)

    The last 1000 samples of the archive form the test split; all preceding
    samples form the training split.

    Returns:
        Dataset: flower photos and labels
    """
    fname = get_datapath('flower_photos/flowers_224.npz')
    # np.load on an .npz returns an NpzFile that keeps the file open; the
    # context manager closes it once both arrays have been read into memory.
    with np.load(fname) as archive:
        X = archive['X']
        y = archive['y']

    X_train, X_test = np.split(X, [-1000])
    y_train, y_test = np.split(y, [-1000])

    data = Dataset()
    data.classes = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
    data.train_images = X_train
    data.train_labels = y_train
    data.test_images = X_test
    data.test_labels = y_test
    return data
예제 #6
0
파일: mnist.py 프로젝트: jglombitza/dlipr
def load_data():
    """Load the MNIST dataset.

    Each split is read from a gzip-compressed IDX file: a 32-bit magic number,
    big-endian 32-bit dimension sizes, then the raw uint8 payload.

    Returns:
        Dataset: MNIST data

    Raises:
        ValueError: if a file's magic number does not match the expected
            image (2051) or label (2049) header.
    """
    def _read32(bytestream):
        """Read one big-endian unsigned 32-bit integer and return it as an int."""
        dt = np.dtype(np.uint32).newbyteorder('>')
        # Convert to a Python int so the size arithmetic below cannot wrap
        # around in fixed-width uint32.
        return int(np.frombuffer(bytestream.read(4), dtype=dt)[0])

    def _check_magic(bytestream, expected, fname):
        """Read the header magic number; raise ValueError if it is wrong."""
        magic = _read32(bytestream)
        if magic != expected:
            raise ValueError('Invalid magic number %d in MNIST file %s '
                             '(expected %d)' % (magic, fname, expected))

    def _extract_images(fname):
        """Return the images in ``fname`` as a (num_images, rows, cols) uint8 array."""
        # gzip.open owns and closes the underlying file; wrapping a bare
        # open() in GzipFile(fileobj=...) leaves that handle open.
        with gzip.open(fname, 'rb') as bytestream:
            _check_magic(bytestream, 2051, fname)  # 0x0803: uint8, 3 dims
            num_images = _read32(bytestream)
            rows = _read32(bytestream)
            cols = _read32(bytestream)
            buf = bytestream.read(rows * cols * num_images)
            data = np.frombuffer(buf, dtype=np.uint8)
            return data.reshape(num_images, rows, cols)

    def _extract_labels(fname):
        """Return the labels in ``fname`` as a 1-D uint8 array."""
        with gzip.open(fname, 'rb') as bytestream:
            _check_magic(bytestream, 2049, fname)  # 0x0801: uint8, 1 dim
            num_items = _read32(bytestream)
            buf = bytestream.read(num_items)
            return np.frombuffer(buf, dtype=np.uint8)

    data = Dataset()
    data.train_images = _extract_images(
        get_datapath('MNIST/train-images-idx3-ubyte.gz'))
    data.train_labels = _extract_labels(
        get_datapath('MNIST/train-labels-idx1-ubyte.gz'))
    data.test_images = _extract_images(
        get_datapath('MNIST/t10k-images-idx3-ubyte.gz'))
    data.test_labels = _extract_labels(
        get_datapath('MNIST/t10k-labels-idx1-ubyte.gz'))
    data.classes = np.arange(10)
    return data
예제 #7
0
파일: ising.py 프로젝트: jglombitza/dlipr
def load_data():
    """Load the Ising data set.

    Samples are read from the ``C`` array and temperatures from the ``T``
    array of ``Ising/data.npz``. Each temperature is mapped to the index of
    the nearest value on the grid 1.0, 1.1, ..., 3.5, which is stored
    (rounded to one decimal) in ``data.classes``. The first 22000 samples
    form the training split; the rest form the test split.

    Returns:
        Dataset: Ising data
    """
    # Close the NpzFile once both arrays have been read into memory.
    with np.load(get_datapath('Ising/data.npz')) as archive:
        X = archive['C']
        T = archive['T']

    temperatures = np.arange(1, 3.51, 0.1)
    # Label = index of the nearest grid temperature. Searching against the
    # midpoints between grid values avoids the exact float comparison of
    # searchsorted(temperatures, T): that returns i + 1 whenever a stored
    # temperature is a hair above the arange value (e.g. stored in another
    # precision). Exact grid values get the same label as before.
    midpoints = (temperatures[:-1] + temperatures[1:]) / 2
    y = np.searchsorted(midpoints, T)

    X_train, X_test = np.split(X, [22000])
    y_train, y_test = np.split(y, [22000])

    data = Dataset()
    data.classes = np.around(temperatures, decimals=1)
    data.train_images = X_train
    data.train_labels = y_train
    data.test_images = X_test
    data.test_labels = y_test
    return data