def test_read_labels():
    """Check that read_mnist_labels parses a small hand-built label file.

    The file consists of two big-endian 32-bit integers (the label magic
    number and the label count) followed by one unsigned byte per label.
    """
    expected = (9, 4, 3, 1)
    payload = struct.pack('>iiBBBB', MNIST_LABEL_MAGIC, len(expected),
                          *expected)
    with tempfile.TemporaryFile() as handle:
        handle.write(payload)
        # Rewind so the reader starts at the header.
        handle.seek(0)
        labels = read_mnist_labels(handle)
        assert labels.shape == (len(expected),)
        assert labels.dtype == numpy.dtype('uint8')
        for position, value in enumerate(expected):
            assert labels[position] == value
示例#2
0
    def __init__(self, which_set, center=False, shuffle=False,
                 one_hot=None, binarize=False, start=None,
                 stop=None, axes=['b', 0, 1, 'c'],
                 preprocessor=None,
                 fit_preprocessor=False,
                 fit_test_preprocessor=False):
        """
        Build the MNIST train or test set as a dense design matrix.

        When `control.get_load_data()` is false, random stand-in data of
        the same size is generated instead of reading the files.

        Parameters
        ----------
        which_set : str
            Either 'train' or 'test'.
        center : bool
            If true, the per-pixel mean over all examples is subtracted.
        shuffle : bool
            If true, examples (and their labels) are permuted by random
            swaps driven by an RNG seeded with [1, 2, 3].
        one_hot : object
            Deprecated. Any value other than None only triggers a warning.
        binarize : bool
            If true, pixels are thresholded at 0.5 and stored as float32.
        start, stop : int or None
            If `start` is not None, only examples `start:stop` are kept.
        axes : sequence
            Axis order of the topological view handed to the superclass;
            a permutation of ('b', 0, 1, 'c').
        preprocessor : object or None
            If truthy, `preprocessor.apply(self, fit_preprocessor)` is
            called once the data is in place.
        fit_preprocessor : bool
            Forwarded to `preprocessor.apply`.
        fit_test_preprocessor : bool or None
            For the test set, must be None or equal to `fit_preprocessor`.

        Raises
        ------
        ValueError
            If `which_set` is not 'train' or 'test', or if `stop` exceeds
            the number of examples.
        """
        self.args = locals()

        if which_set not in ['train', 'test']:
            if which_set == 'valid':
                # Each fragment ends with a space: adjacent literals are
                # concatenated verbatim.
                raise ValueError(
                    "There is no such thing as the MNIST validation set. "
                    "MNIST consists of 60,000 train examples and 10,000 "
                    "test examples. If you wish to use a validation set "
                    "you should divide the train set yourself. The "
                    "pylearn2 dataset implements and will only ever "
                    "implement the standard train / test split used in "
                    "the literature.")
            raise ValueError(
                'Unrecognized which_set value "%s". '
                'Valid values are ["train","test"].' % (which_set,))

        def dimshuffle(b01c):
            """
            Transpose a ('b', 0, 1, 'c') array into the order given by
            the enclosing `axes` argument.
            """
            default = ('b', 0, 1, 'c')
            return b01c.transpose(*[default.index(axis) for axis in axes])

        if control.get_load_data():
            path = "${PYLEARN2_DATA_PATH}/mnist/"
            if which_set == 'train':
                im_path = path + 'train-images-idx3-ubyte'
                label_path = path + 'train-labels-idx1-ubyte'
            else:
                assert which_set == 'test'
                im_path = path + 't10k-images-idx3-ubyte'
                label_path = path + 't10k-labels-idx1-ubyte'
            # Path substitution done here in order to make the lower-level
            # mnist_ubyte.py as stand-alone as possible (for reuse in, e.g.,
            # the Deep Learning Tutorials, or in another package).
            im_path = serial.preprocess(im_path)
            label_path = serial.preprocess(label_path)

            # Locally cache the files before reading them
            datasetCache = cache.datasetCache
            im_path = datasetCache.cache_file(im_path)
            label_path = datasetCache.cache_file(label_path)

            topo_view = read_mnist_images(im_path, dtype='float32')
            # Labels become a column vector of shape (m, 1).
            y = np.atleast_2d(read_mnist_labels(label_path)).T
        else:
            # Data loading is disabled: fabricate random data with the
            # same number of examples as the real set.
            if which_set == 'train':
                size = 60000
            elif which_set == 'test':
                size = 10000
            else:
                raise ValueError(
                    'Unrecognized which_set value "%s". '
                    'Valid values are ["train","test"].' % (which_set,))
            topo_view = np.random.rand(size, 28, 28)
            y = np.random.randint(0, 10, (size, 1))

        if binarize:
            topo_view = (topo_view > 0.5).astype('float32')

        max_labels = 10
        if one_hot is not None:
            warnings.warn("the `one_hot` parameter is deprecated. To get "
                          "one-hot encoded targets, request that they "
                          "live in `VectorSpace` through the `data_specs` "
                          "parameter of MNIST's iterator method. "
                          "`one_hot` will be removed on or after "
                          "September 20, 2014.", stacklevel=2)

        m, r, c = topo_view.shape
        assert r == 28
        assert c == 28
        # Add a trailing channel axis of length one.
        topo_view = topo_view.reshape(m, r, c, 1)

        if which_set == 'train':
            assert m == 60000
        elif which_set == 'test':
            assert m == 10000
        else:
            assert False

        if center:
            topo_view -= topo_view.mean(axis=0)

        if shuffle:
            self.shuffle_rng = make_np_rng(None, [1, 2, 3],
                                           which_method="shuffle")
            for i in xrange(topo_view.shape[0]):
                j = self.shuffle_rng.randint(m)
                # Copy ensures that memory is not aliased.
                tmp = topo_view[i, :, :, :].copy()
                topo_view[i, :, :, :] = topo_view[j, :, :, :]
                topo_view[j, :, :, :] = tmp
                # Note: slicing with i:i+1 works for one_hot=True/False
                tmp = y[i:i+1].copy()
                y[i] = y[j]
                y[j] = tmp

        super(MNIST, self).__init__(topo_view=dimshuffle(topo_view), y=y,
                                    axes=axes, y_labels=max_labels)

        assert not N.any(N.isnan(self.X))

        if start is not None:
            assert start >= 0
            if stop > self.X.shape[0]:
                raise ValueError('stop=' + str(stop) + '>' +
                                 'm=' + str(self.X.shape[0]))
            assert stop > start
            self.X = self.X[start:stop, :]
            if self.X.shape[0] != stop - start:
                raise ValueError("X.shape[0]: %d. start: %d stop: %d"
                                 % (self.X.shape[0], start, stop))
            if len(self.y.shape) > 1:
                self.y = self.y[start:stop, :]
            else:
                self.y = self.y[start:stop]
            assert self.y.shape[0] == stop - start

        if which_set == 'test':
            assert fit_test_preprocessor is None or \
                (fit_preprocessor == fit_test_preprocessor)

        if self.X is not None and preprocessor:
            preprocessor.apply(self, fit_preprocessor)
示例#3
0
    def __init__(self,
                 which_set,
                 center=False,
                 shuffle=False,
                 one_hot=False,
                 binarize=False,
                 start=None,
                 stop=None):
        """
        Build the MNIST train or test set as a dense design matrix.

        When `control.get_load_data()` is false, only a zero-filled
        (1, 28, 28, 1) topological view is registered and `self.X` is set
        to None.

        Parameters
        ----------
        which_set : str
            Either 'train' or 'test'.
        center : bool
            If true, the per-pixel mean over all examples is subtracted.
        shuffle : bool
            If true, examples (and their labels) are permuted by random
            swaps driven by a RandomState seeded with [1, 2, 3].
        one_hot : bool
            If true, labels are expanded into a float32 matrix with 10
            columns holding a single 1 per row.
        binarize : bool
            If true, pixels are thresholded at 0.5 and stored as float32.
        start, stop : int or None
            If `start` is not None, only examples `start:stop` are kept.

        Raises
        ------
        ValueError
            If `which_set` is not 'train' or 'test', or if `stop` exceeds
            the number of examples.
        """

        self.args = locals()

        if which_set not in ['train', 'test']:
            if which_set == 'valid':
                raise ValueError(
                    "There is no such thing as the MNIST "
                    "validation set. MNIST consists of 60,000 train "
                    "examples and 10,000 test examples. If you wish to "
                    "use a validation set you should divide the train "
                    "set yourself. The pylearn2 dataset implements and "
                    "will only ever implement the standard train / test "
                    "split used in the literature.")
            raise ValueError('Unrecognized which_set value "%s". '
                             'Valid values are ["train","test"].' %
                             (which_set, ))

        if control.get_load_data():
            path = "${PYLEARN2_DATA_PATH}/mnist/"
            if which_set == 'train':
                im_path = path + 'train-images-idx3-ubyte'
                label_path = path + 'train-labels-idx1-ubyte'
            else:
                assert which_set == 'test'
                im_path = path + 't10k-images-idx3-ubyte'
                label_path = path + 't10k-labels-idx1-ubyte'

            topo_view = read_mnist_images(im_path, dtype='float32')
            y = read_mnist_labels(label_path)

            if binarize:
                topo_view = (topo_view > 0.5).astype('float32')

            # Remember the caller's choice before the name is reused for
            # the encoded label matrix below.
            self.one_hot = one_hot
            if one_hot:
                one_hot = N.zeros((y.shape[0], 10), dtype='float32')
                for i in xrange(y.shape[0]):
                    one_hot[i, y[i]] = 1.
                y = one_hot

            m, r, c = topo_view.shape
            assert r == 28
            assert c == 28
            # Add a trailing channel axis of length one.
            topo_view = topo_view.reshape(m, r, c, 1)

            if which_set == 'train':
                assert m == 60000
            elif which_set == 'test':
                assert m == 10000
            else:
                assert False

            if center:
                topo_view -= topo_view.mean(axis=0)

            if shuffle:
                self.shuffle_rng = np.random.RandomState([1, 2, 3])
                for i in xrange(topo_view.shape[0]):
                    j = self.shuffle_rng.randint(m)
                    # Copy ensures that memory is not aliased.
                    tmp = topo_view[i, :, :, :].copy()
                    topo_view[i, :, :, :] = topo_view[j, :, :, :]
                    topo_view[j, :, :, :] = tmp
                    # Note: slicing with i:i+1 works for both
                    # one_hot=True/False.
                    tmp = y[i:i + 1].copy()
                    y[i] = y[j]
                    y[j] = tmp

            super(MNIST, self).__init__(topo_view=topo_view, y=y)

            assert not N.any(N.isnan(self.X))

            if start is not None:
                assert start >= 0
                if stop > self.X.shape[0]:
                    raise ValueError('stop=' + str(stop) + '>' + 'm=' +
                                     str(self.X.shape[0]))
                assert stop > start
                self.X = self.X[start:stop, :]
                if self.X.shape[0] != stop - start:
                    raise ValueError("X.shape[0]: %d. start: %d stop: %d" %
                                     (self.X.shape[0], start, stop))
                if len(self.y.shape) > 1:
                    self.y = self.y[start:stop, :]
                else:
                    self.y = self.y[start:stop]
                assert self.y.shape[0] == stop - start
        else:
            # data loading is disabled, just make something that defines
            # the right topology
            topo = np.zeros((1, 28, 28, 1))
            super(MNIST, self).__init__(topo_view=topo)
            self.X = None
示例#4
0
    def __init__(self,
                 which_set,
                 center=False,
                 shuffle=False,
                 one_hot=False,
                 binarize=False,
                 start=None,
                 stop=None,
                 axes=['b', 0, 1, 'c'],
                 preprocessor=None,
                 fit_preprocessor=False,
                 fit_test_preprocessor=False):
        """
        Build the MNIST train or test set as a dense design matrix.

        When `control.get_load_data()` is false, only a zero-filled
        topological view is registered and `self.X` is set to None.

        Parameters
        ----------
        which_set : str
            Either 'train' or 'test'.
        center : bool
            If true, the per-pixel mean over all examples is subtracted.
        shuffle : bool
            If true, examples (and their labels) are permuted by random
            swaps driven by an RNG seeded with [1, 2, 3].
        one_hot : bool
            If true, labels are expanded into a float32 matrix with 10
            columns and `max_labels=None` is passed to the superclass;
            otherwise `max_labels=10` is passed.
        binarize : bool
            If true, pixels are thresholded at 0.5 and stored as float32.
        start, stop : int or None
            If `start` is not None, only examples `start:stop` are kept.
        axes : sequence
            Axis order of the topological view handed to the superclass;
            a permutation of ('b', 0, 1, 'c').
        preprocessor : object or None
            If truthy and data was loaded,
            `preprocessor.apply(self, fit_preprocessor)` is called.
        fit_preprocessor : bool
            Forwarded to `preprocessor.apply`.
        fit_test_preprocessor : bool or None
            For the test set, must be None or equal to `fit_preprocessor`.

        Raises
        ------
        ValueError
            If `which_set` is not 'train' or 'test', or if `stop` exceeds
            the number of examples.
        """

        self.args = locals()

        if which_set not in ['train', 'test']:
            if which_set == 'valid':
                # Each fragment ends with a space: adjacent literals are
                # concatenated verbatim.
                raise ValueError(
                    "There is no such thing as the MNIST validation set. "
                    "MNIST consists of 60,000 train examples and 10,000 "
                    "test examples. If you wish to use a validation set "
                    "you should divide the train set yourself. The "
                    "pylearn2 dataset implements and will only ever "
                    "implement the standard train / test split used in "
                    "the literature.")
            raise ValueError('Unrecognized which_set value "%s". '
                             'Valid values are ["train","test"].' %
                             (which_set, ))

        def dimshuffle(b01c):
            """
            Transpose a ('b', 0, 1, 'c') array into the order given by
            the enclosing `axes` argument.
            """
            default = ('b', 0, 1, 'c')
            return b01c.transpose(*[default.index(axis) for axis in axes])

        if control.get_load_data():
            path = "${PYLEARN2_DATA_PATH}/mnist/"
            if which_set == 'train':
                im_path = path + 'train-images-idx3-ubyte'
                label_path = path + 'train-labels-idx1-ubyte'
            else:
                assert which_set == 'test'
                im_path = path + 't10k-images-idx3-ubyte'
                label_path = path + 't10k-labels-idx1-ubyte'
            # Path substitution done here in order to make the lower-level
            # mnist_ubyte.py as stand-alone as possible (for reuse in, e.g.,
            # the Deep Learning Tutorials, or in another package).
            im_path = serial.preprocess(im_path)
            label_path = serial.preprocess(label_path)
            topo_view = read_mnist_images(im_path, dtype='float32')
            y = read_mnist_labels(label_path)

            if binarize:
                topo_view = (topo_view > 0.5).astype('float32')

            # Remember the caller's choice before the name is reused for
            # the encoded label matrix below.
            self.one_hot = one_hot
            if one_hot:
                one_hot = N.zeros((y.shape[0], 10), dtype='float32')
                for i in xrange(y.shape[0]):
                    one_hot[i, y[i]] = 1.
                y = one_hot
                max_labels = None
            else:
                max_labels = 10

            m, r, c = topo_view.shape
            assert r == 28
            assert c == 28
            # Add a trailing channel axis of length one.
            topo_view = topo_view.reshape(m, r, c, 1)

            if which_set == 'train':
                assert m == 60000
            elif which_set == 'test':
                assert m == 10000
            else:
                assert False

            if center:
                topo_view -= topo_view.mean(axis=0)

            if shuffle:
                self.shuffle_rng = make_np_rng(None, [1, 2, 3],
                                               which_method="shuffle")
                for i in xrange(topo_view.shape[0]):
                    j = self.shuffle_rng.randint(m)
                    # Copy ensures that memory is not aliased.
                    tmp = topo_view[i, :, :, :].copy()
                    topo_view[i, :, :, :] = topo_view[j, :, :, :]
                    topo_view[j, :, :, :] = tmp
                    # Note: slicing with i:i+1 works for one_hot=True/False
                    tmp = y[i:i + 1].copy()
                    y[i] = y[j]
                    y[j] = tmp

            super(MNIST, self).__init__(topo_view=dimshuffle(topo_view),
                                        y=y,
                                        axes=axes,
                                        max_labels=max_labels)

            assert not N.any(N.isnan(self.X))

            if start is not None:
                assert start >= 0
                if stop > self.X.shape[0]:
                    raise ValueError('stop=' + str(stop) + '>' + 'm=' +
                                     str(self.X.shape[0]))
                assert stop > start
                self.X = self.X[start:stop, :]
                if self.X.shape[0] != stop - start:
                    raise ValueError("X.shape[0]: %d. start: %d stop: %d" %
                                     (self.X.shape[0], start, stop))
                if len(self.y.shape) > 1:
                    self.y = self.y[start:stop, :]
                else:
                    self.y = self.y[start:stop]
                assert self.y.shape[0] == stop - start
        else:
            # data loading is disabled, just make something that defines the
            # right topology
            topo = dimshuffle(np.zeros((1, 28, 28, 1)))
            super(MNIST, self).__init__(topo_view=topo, axes=axes)
            self.X = None

        if which_set == 'test':
            assert fit_test_preprocessor is None or \
                (fit_preprocessor == fit_test_preprocessor)

        if self.X is not None and preprocessor:
            preprocessor.apply(self, fit_preprocessor)
示例#5
0
    def __init__(self, which_set, center=False, shuffle=False,
                 one_hot=False, binarize=False, start=None,
                 stop=None, axes=['b', 0, 1, 'c'],
                 preprocessor=None,
                 fit_preprocessor=False,
                 fit_test_preprocessor=False):
        """
        Build the MNIST train or test set as a dense design matrix.

        When `control.get_load_data()` is false, only a zero-filled
        topological view is registered and `self.X` is set to None.

        Parameters
        ----------
        which_set : str
            Either 'train' or 'test'.
        center : bool
            If true, the per-pixel mean over all examples is subtracted.
        shuffle : bool
            If true, examples (and their labels) are permuted by random
            swaps driven by a RandomState seeded with [1, 2, 3].
        one_hot : bool
            If true, labels are expanded into a float32 matrix with 10
            columns holding a single 1 per row.
        binarize : bool
            If true, pixels are thresholded at 0.5 and stored as float32.
        start, stop : int or None
            If `start` is not None, only examples `start:stop` are kept.
        axes : sequence
            Axis order of the topological view handed to the superclass;
            a permutation of ('b', 0, 1, 'c').
        preprocessor : object or None
            If truthy and data was loaded,
            `preprocessor.apply(self, fit_preprocessor)` is called.
        fit_preprocessor : bool
            Forwarded to `preprocessor.apply`.
        fit_test_preprocessor : bool or None
            For the test set, must be None or equal to `fit_preprocessor`.

        Raises
        ------
        ValueError
            If `which_set` is not 'train' or 'test', or if `stop` exceeds
            the number of examples.
        """

        self.args = locals()

        if which_set not in ['train', 'test']:
            if which_set == 'valid':
                raise ValueError(
                    "There is no such thing as the MNIST "
                    "validation set. MNIST consists of 60,000 train "
                    "examples and 10,000 test examples. If you wish to "
                    "use a validation set you should divide the train "
                    "set yourself. The pylearn2 dataset implements and "
                    "will only ever implement the standard train / test "
                    "split used in the literature.")
            raise ValueError('Unrecognized which_set value "%s". '
                             'Valid values are ["train","test"].' %
                             (which_set,))

        def dimshuffle(b01c):
            """
            Transpose a ('b', 0, 1, 'c') array into the order given by
            the enclosing `axes` argument.
            """
            default = ('b', 0, 1, 'c')
            return b01c.transpose(*[default.index(axis) for axis in axes])

        if control.get_load_data():
            path = "${PYLEARN2_DATA_PATH}/mnist/"
            if which_set == 'train':
                im_path = path + 'train-images-idx3-ubyte'
                label_path = path + 'train-labels-idx1-ubyte'
            else:
                assert which_set == 'test'
                im_path = path + 't10k-images-idx3-ubyte'
                label_path = path + 't10k-labels-idx1-ubyte'
            # Path substitution done here in order to make the lower-level
            # mnist_ubyte.py as stand-alone as possible (for reuse in, e.g.,
            # the Deep Learning Tutorials, or in another package).
            im_path = serial.preprocess(im_path)
            label_path = serial.preprocess(label_path)
            topo_view = read_mnist_images(im_path, dtype='float32')
            y = read_mnist_labels(label_path)

            if binarize:
                topo_view = (topo_view > 0.5).astype('float32')

            # Remember the caller's choice before the name is reused for
            # the encoded label matrix below.
            self.one_hot = one_hot
            if one_hot:
                one_hot = N.zeros((y.shape[0], 10), dtype='float32')
                for i in xrange(y.shape[0]):
                    one_hot[i, y[i]] = 1.
                y = one_hot

            m, r, c = topo_view.shape
            assert r == 28
            assert c == 28
            # Add a trailing channel axis of length one.
            topo_view = topo_view.reshape(m, r, c, 1)

            if which_set == 'train':
                assert m == 60000
            elif which_set == 'test':
                assert m == 10000
            else:
                assert False

            if center:
                topo_view -= topo_view.mean(axis=0)

            if shuffle:
                self.shuffle_rng = np.random.RandomState([1, 2, 3])
                for i in xrange(topo_view.shape[0]):
                    j = self.shuffle_rng.randint(m)
                    # Copy ensures that memory is not aliased.
                    tmp = topo_view[i, :, :, :].copy()
                    topo_view[i, :, :, :] = topo_view[j, :, :, :]
                    topo_view[j, :, :, :] = tmp
                    # Note: slicing with i:i+1 works for both
                    # one_hot=True/False.
                    tmp = y[i:i+1].copy()
                    y[i] = y[j]
                    y[j] = tmp

            super(MNIST, self).__init__(topo_view=dimshuffle(topo_view),
                                        y=y, axes=axes)

            assert not N.any(N.isnan(self.X))

            if start is not None:
                assert start >= 0
                if stop > self.X.shape[0]:
                    raise ValueError('stop=' + str(stop) + '>' + 'm=' +
                                     str(self.X.shape[0]))
                assert stop > start
                self.X = self.X[start:stop, :]
                if self.X.shape[0] != stop - start:
                    raise ValueError("X.shape[0]: %d. start: %d stop: %d" %
                                     (self.X.shape[0], start, stop))
                if len(self.y.shape) > 1:
                    self.y = self.y[start:stop, :]
                else:
                    self.y = self.y[start:stop]
                assert self.y.shape[0] == stop - start
        else:
            # data loading is disabled, just make something that defines the
            # right topology
            topo = dimshuffle(np.zeros((1, 28, 28, 1)))
            super(MNIST, self).__init__(topo_view=topo, axes=axes)
            self.X = None

        if which_set == 'test':
            assert fit_test_preprocessor is None or \
                (fit_preprocessor == fit_test_preprocessor)

        if self.X is not None and preprocessor:
            preprocessor.apply(self, fit_preprocessor)
示例#6
0
    def __init__(self, which_set, center=False, shuffle=False,
                 one_hot=False, binarize=False):
        """
        Build the MNIST train or test set as a dense design matrix.

        When `control.get_load_data()` is false, only a zero-filled
        (1, 28, 28, 1) topological view is registered and `self.X` is set
        to None.

        Parameters
        ----------
        which_set : str
            Either 'train' or 'test'.
        center : bool
            If true, the per-pixel mean over all examples is subtracted.
        shuffle : bool
            If true, examples (and their labels) are permuted by random
            swaps driven by a RandomState seeded with [1, 2, 3].
        one_hot : bool
            If true, labels are expanded into a float32 matrix with 10
            columns holding a single 1 per row.
        binarize : bool
            If true, pixels are thresholded at 0.5 and stored as float32.

        Raises
        ------
        ValueError
            If `which_set` is not 'train' or 'test'.
        """

        if which_set not in ['train', 'test']:
            if which_set == 'valid':
                raise ValueError(
                    "There is no such thing as the MNIST "
                    "validation set. MNIST consists of 60,000 train "
                    "examples and 10,000 test examples. If you wish to "
                    "use a validation set you should divide the train "
                    "set yourself. The pylearn2 dataset implements and "
                    "will only ever implement the standard train / test "
                    "split used in the literature.")
            raise ValueError('Unrecognized which_set value "%s". '
                             'Valid values are ["train","test"].' %
                             (which_set,))

        if control.get_load_data():
            path = "${PYLEARN2_DATA_PATH}/mnist/"
            if which_set == 'train':
                im_path = path + 'train-images-idx3-ubyte'
                label_path = path + 'train-labels-idx1-ubyte'
            else:
                assert which_set == 'test'
                im_path = path + 't10k-images-idx3-ubyte'
                label_path = path + 't10k-labels-idx1-ubyte'

            topo_view = read_mnist_images(im_path, dtype='float32')
            y = read_mnist_labels(label_path)

            if binarize:
                topo_view = (topo_view > 0.5).astype('float32')

            # Remember the caller's choice before the name is reused for
            # the encoded label matrix below.
            self.one_hot = one_hot
            if one_hot:
                one_hot = N.zeros((y.shape[0], 10), dtype='float32')
                for i in xrange(y.shape[0]):
                    one_hot[i, y[i]] = 1.
                y = one_hot

            m, r, c = topo_view.shape
            assert r == 28
            assert c == 28
            # Add a trailing channel axis of length one.
            topo_view = topo_view.reshape(m, r, c, 1)

            if which_set == 'train':
                assert m == 60000
            elif which_set == 'test':
                assert m == 10000
            else:
                assert False

            if center:
                topo_view -= topo_view.mean(axis=0)

            if shuffle:
                self.shuffle_rng = np.random.RandomState([1, 2, 3])
                for i in xrange(topo_view.shape[0]):
                    j = self.shuffle_rng.randint(m)
                    # Basic slicing returns a view, so without the copy
                    # the first assignment would also overwrite `tmp` and
                    # example i would be lost instead of swapped.
                    tmp = topo_view[i, :, :, :].copy()
                    topo_view[i, :, :, :] = topo_view[j, :, :, :]
                    topo_view[j, :, :, :] = tmp
                    # Slicing with i:i+1 and copying works for both 1-D
                    # labels and 2-D one-hot rows (where y[i] is a view).
                    tmp = y[i:i+1].copy()
                    y[i] = y[j]
                    y[j] = tmp

            super(MNIST, self).__init__(topo_view=topo_view, y=y)

            assert not N.any(N.isnan(self.X))
        else:
            # data loading is disabled, just make something that defines
            # the right topology
            topo = np.zeros((1, 28, 28, 1))
            super(MNIST, self).__init__(topo_view=topo)
            self.X = None
示例#7
0
    def __init__(self, which_set, center=False, shuffle=False,
                 binarize=False, start=None, stop=None,
                 axes=['b', 0, 1, 'c'],
                 preprocessor=None,
                 fit_preprocessor=False,
                 fit_test_preprocessor=False):
        """
        Build the train or test set stored in idx format under
        ${PYLEARN2_DATA_PATH}/sign24/ as a dense design matrix.

        When `control.get_load_data()` is false, random stand-in data of
        the same size is generated instead of reading the files.

        Parameters
        ----------
        which_set : str
            Either 'train' (3576 examples) or 'test' (1176 examples).
        center : bool
            If true, the per-pixel mean over all examples is subtracted.
        shuffle : bool
            If true, examples (and their labels) are permuted by random
            swaps driven by an RNG seeded with [1, 2, 3].
        binarize : bool
            If true, pixels are thresholded at 0.5 and stored as float32.
        start, stop : int or None
            If `start` is not None, only examples `start:stop` are kept.
        axes : sequence
            Axis order of the topological view handed to the superclass;
            a permutation of ('b', 0, 1, 'c').
        preprocessor : object or None
            If truthy, `preprocessor.apply(self, fit_preprocessor)` is
            called once the data is in place.
        fit_preprocessor : bool
            Forwarded to `preprocessor.apply`.
        fit_test_preprocessor : bool or None
            For the test set, must be None or equal to `fit_preprocessor`.

        Raises
        ------
        ValueError
            If `which_set` is not 'train' or 'test', or if `stop` exceeds
            the number of examples.
        """
        self.args = locals()

        if which_set not in ['train', 'test']:
            if which_set == 'valid':
                # NOTE(review): the wording still describes MNIST although
                # this variant reads the sign24 files - confirm and update.
                raise ValueError(
                    "There is no such thing as the MNIST validation set. "
                    "MNIST consists of 60,000 train examples and 10,000 "
                    "test examples. If you wish to use a validation set "
                    "you should divide the train set yourself. The "
                    "pylearn2 dataset implements and will only ever "
                    "implement the standard train / test split used in "
                    "the literature.")
            raise ValueError(
                'Unrecognized which_set value "%s". '
                'Valid values are ["train","test"].' % (which_set,))

        def dimshuffle(b01c):
            """
            Transpose a ('b', 0, 1, 'c') array into the order given by
            the enclosing `axes` argument.
            """
            default = ('b', 0, 1, 'c')
            return b01c.transpose(*[default.index(axis) for axis in axes])

        if control.get_load_data():
            path = "${PYLEARN2_DATA_PATH}/sign24/"
            if which_set == 'train':
                im_path = path + 'train-images-idx3-ubyte'
                label_path = path + 'train-labels-idx1-ubyte'
            else:
                assert which_set == 'test'
                im_path = path + 't10k-images-idx3-ubyte'
                label_path = path + 't10k-labels-idx1-ubyte'
            # Path substitution done here in order to make the lower-level
            # mnist_ubyte.py as stand-alone as possible (for reuse in, e.g.,
            # the Deep Learning Tutorials, or in another package).
            im_path = serial.preprocess(im_path)
            label_path = serial.preprocess(label_path)

            # Locally cache the files before reading them
            datasetCache = cache.datasetCache
            im_path = datasetCache.cache_file(im_path)
            label_path = datasetCache.cache_file(label_path)

            topo_view = read_mnist_images(im_path, dtype='float32')
            # Labels become a column vector of shape (m, 1).
            y = np.atleast_2d(read_mnist_labels(label_path)).T
        else:
            # Data loading is disabled: fabricate random data. The sizes
            # must match the example-count assertions further down,
            # otherwise this branch can never complete.
            if which_set == 'train':
                size = 3576
            elif which_set == 'test':
                size = 1176
            else:
                raise ValueError(
                    'Unrecognized which_set value "%s". '
                    'Valid values are ["train","test"].' % (which_set,))
            topo_view = np.random.rand(size, 28, 28)
            # NOTE(review): stand-in labels cover only 0..9 although
            # y_labels is 24 - confirm whether that is intended.
            y = np.random.randint(0, 10, (size, 1))

        if binarize:
            topo_view = (topo_view > 0.5).astype('float32')

        y_labels = 24

        m, r, c = topo_view.shape
        assert r == 28
        assert c == 28
        # Add a trailing channel axis of length one.
        topo_view = topo_view.reshape(m, r, c, 1)

        if which_set == 'train':
            assert m == 3576
        elif which_set == 'test':
            assert m == 1176
        else:
            assert False

        if center:
            topo_view -= topo_view.mean(axis=0)

        if shuffle:
            self.shuffle_rng = make_np_rng(
                None, [1, 2, 3], which_method="shuffle")
            for i in xrange(topo_view.shape[0]):
                j = self.shuffle_rng.randint(m)
                # Copy ensures that memory is not aliased.
                tmp = topo_view[i, :, :, :].copy()
                topo_view[i, :, :, :] = topo_view[j, :, :, :]
                topo_view[j, :, :, :] = tmp

                tmp = y[i:i + 1].copy()
                y[i] = y[j]
                y[j] = tmp

        super(MNIST, self).__init__(topo_view=dimshuffle(topo_view), y=y,
                                    axes=axes, y_labels=y_labels)

        assert not N.any(N.isnan(self.X))

        if start is not None:
            assert start >= 0
            if stop > self.X.shape[0]:
                raise ValueError('stop=' + str(stop) + '>' +
                                 'm=' + str(self.X.shape[0]))
            assert stop > start
            self.X = self.X[start:stop, :]
            if self.X.shape[0] != stop - start:
                raise ValueError("X.shape[0]: %d. start: %d stop: %d"
                                 % (self.X.shape[0], start, stop))
            if len(self.y.shape) > 1:
                self.y = self.y[start:stop, :]
            else:
                self.y = self.y[start:stop]
            assert self.y.shape[0] == stop - start

        if which_set == 'test':
            assert fit_test_preprocessor is None or \
                (fit_preprocessor == fit_test_preprocessor)

        if self.X is not None and preprocessor:
            preprocessor.apply(self, fit_preprocessor)
示例#8
0
文件: mnist.py 项目: scyoyo/pylearn
    def __init__(self, which_set, center=False, shuffle=False, one_hot=False, binarize=False, start=None, stop=None):
        """
        Build the MNIST train or test set as a dense design matrix.

        When `control.get_load_data()` is false, only a zero-filled
        (1, 28, 28, 1) topological view is registered and `self.X` is set
        to None.

        Parameters
        ----------
        which_set : str
            Either "train" or "test".
        center : bool
            If true, the per-pixel mean over all examples is subtracted.
        shuffle : bool
            If true, examples (and their labels) are permuted by random
            swaps driven by a RandomState seeded with [1, 2, 3].
        one_hot : bool
            If true, labels are expanded into a float32 matrix with 10
            columns holding a single 1 per row.
        binarize : bool
            If true, pixels are thresholded at 0.5 and stored as float32.
        start, stop : int or None
            If `start` is not None, only examples `start:stop` are kept.

        Raises
        ------
        ValueError
            If `which_set` is not "train" or "test", or if `stop` exceeds
            the number of examples.
        """

        self.args = locals()

        if which_set not in ["train", "test"]:
            if which_set == "valid":
                raise ValueError(
                    "There is no such thing as the MNIST "
                    "validation set. MNIST consists of 60,000 train examples and 10,000 test"
                    " examples. If you wish to use a validation set you should divide the train "
                    "set yourself. The pylearn2 dataset implements and will only ever implement "
                    "the standard train / test split used in the literature."
                )
            raise ValueError(
                'Unrecognized which_set value "%s". Valid values are ["train","test"].' % (which_set,)
            )

        if control.get_load_data():
            path = "${PYLEARN2_DATA_PATH}/mnist/"
            if which_set == "train":
                im_path = path + "train-images-idx3-ubyte"
                label_path = path + "train-labels-idx1-ubyte"
            else:
                assert which_set == "test"
                im_path = path + "t10k-images-idx3-ubyte"
                label_path = path + "t10k-labels-idx1-ubyte"
            # Path substitution done here in order to make the lower-level
            # mnist_ubyte.py as stand-alone as possible (for reuse in, e.g.,
            # the Deep Learning Tutorials, or in another package).
            im_path = serial.preprocess(im_path)
            label_path = serial.preprocess(label_path)
            topo_view = read_mnist_images(im_path, dtype="float32")
            y = read_mnist_labels(label_path)

            if binarize:
                topo_view = (topo_view > 0.5).astype("float32")

            # Remember the caller's choice before the name is reused for
            # the encoded label matrix below.
            self.one_hot = one_hot
            if one_hot:
                one_hot = N.zeros((y.shape[0], 10), dtype="float32")
                for i in xrange(y.shape[0]):
                    one_hot[i, y[i]] = 1.0
                y = one_hot

            m, r, c = topo_view.shape
            assert r == 28
            assert c == 28
            # Add a trailing channel axis of length one.
            topo_view = topo_view.reshape(m, r, c, 1)

            if which_set == "train":
                assert m == 60000
            elif which_set == "test":
                assert m == 10000
            else:
                assert False

            if center:
                topo_view -= topo_view.mean(axis=0)

            if shuffle:
                self.shuffle_rng = np.random.RandomState([1, 2, 3])
                for i in xrange(topo_view.shape[0]):
                    j = self.shuffle_rng.randint(m)
                    # Copy ensures that memory is not aliased.
                    tmp = topo_view[i, :, :, :].copy()
                    topo_view[i, :, :, :] = topo_view[j, :, :, :]
                    topo_view[j, :, :, :] = tmp
                    # Note: slicing with i:i+1 works for both one_hot=True/False.
                    tmp = y[i : i + 1].copy()
                    y[i] = y[j]
                    y[j] = tmp

            super(MNIST, self).__init__(topo_view=topo_view, y=y)

            assert not N.any(N.isnan(self.X))

            if start is not None:
                assert start >= 0
                if stop > self.X.shape[0]:
                    raise ValueError("stop=" + str(stop) + ">" + "m=" + str(self.X.shape[0]))
                assert stop > start
                self.X = self.X[start:stop, :]
                if self.X.shape[0] != stop - start:
                    raise ValueError("X.shape[0]: %d. start: %d stop: %d" % (self.X.shape[0], start, stop))
                if len(self.y.shape) > 1:
                    self.y = self.y[start:stop, :]
                else:
                    self.y = self.y[start:stop]
                assert self.y.shape[0] == stop - start
        else:
            # data loading is disabled, just make something that defines the right topology
            topo = np.zeros((1, 28, 28, 1))
            super(MNIST, self).__init__(topo_view=topo)
            self.X = None
    def __init__(self, which_set, pos_class_digit=7, neg_class_digit=9, center=False, shuffle=False,
                 one_hot=None, binarize=False, start=None, stop=None, axes=['b', 0, 1, 'c'],
                 preprocessor=None, fit_preprocessor=False, fit_test_preprocessor=False,
                 X_aug=None, Y_aug=None, clip_size=None, labeler_ai=None, balance_classes=True):
        """
        Build a two-class (positive digit vs. negative digit(s)) MNIST dataset.

        Parameters
        ----------
        which_set : str
            Either 'train' or 'test'; anything else raises ValueError.
        pos_class_digit : int
            The single digit relabelled as class 1. A list or tuple raises
            ValueError.
        neg_class_digit : int or sequence of int
            The digit, or digits, relabelled as class 0.
        center, shuffle, one_hot
            Accepted and recorded in `self.args`, but not otherwise used by
            this constructor.
        binarize : bool
            If true, pixels are thresholded at 0.5 and stored as float32.
        start, stop : int or None
            Row range kept from the design matrix after class selection.
            With `start` set and `stop` None, everything from `start` on is
            kept. `stop` without `start` is ignored.
        axes : sequence
            A permutation of ('b', 0, 1, 'c') giving the axis order of the
            topological view handed to the parent constructor.
        preprocessor : object or None
            If truthy, `preprocessor.apply(self, fit_preprocessor)` is called
            at the end of construction.
        fit_preprocessor, fit_test_preprocessor : bool
            `fit_preprocessor` is forwarded to `preprocessor.apply`. For the
            test set the two flags must be equal (or the latter None).
        X_aug, Y_aug : ndarray or None
            Extra images / labels prepended to the loaded data. When given,
            a non-None `stop` is shifted by `Y_aug.size`.
        clip_size : pair of int or None
            Number of pixel rows / columns cropped from each image border.
        labeler_ai : callable or None
            If given, it is called on every (clipped) image and its outputs
            replace the labels read from disk.
        balance_classes : bool
            If true, both classes are truncated to the size of the smaller
            one and their examples are interleaved. If false, every example
            of either class is kept in its original order.
        """
        self.args = locals()
        if isinstance(pos_class_digit, (list, tuple)):
            raise ValueError(
                "binary_mnist allows multiple digits in the negative class, "
                "but only one digit (not a list!) in the positive class."
                )

        if which_set not in ['train', 'test']:
            if which_set == 'valid':
                raise ValueError(
                    "There is no such thing as the MNIST validation set. MNIST "
                    "consists of 60,000 train examples and 10,000 test "
                    "examples. If you wish to use a validation set you should "
                    "divide the train set yourself. The pylearn2 dataset "
                    "implements and will only ever implement the standard "
                    "train / test split used in the literature.")
            raise ValueError(
                'Unrecognized which_set value "%s". '
                'Valid values are ["train","test"].' % (which_set,))

        def dimshuffle(b01c):
            """
            Transpose a ('b', 0, 1, 'c') array into the order given by `axes`.
            """
            default = ('b', 0, 1, 'c')
            return b01c.transpose(*[default.index(axis) for axis in axes])

        path = "${PYLEARN2_DATA_PATH}/mnist/"
        if which_set == 'train':
            im_path = path + 'train-images-idx3-ubyte'
            label_path = path + 'train-labels-idx1-ubyte'
        else:
            assert which_set == 'test'
            im_path = path + 't10k-images-idx3-ubyte'
            label_path = path + 't10k-labels-idx1-ubyte'
        # Path substitution done here in order to make the lower-level
        # mnist_ubyte.py as stand-alone as possible (for reuse in, e.g.,
        # the Deep Learning Tutorials, or in another package).
        im_path = serial.preprocess(im_path)
        label_path = serial.preprocess(label_path)

        # Locally cache the files before reading them
        datasetCache = cache.datasetCache
        im_path = datasetCache.cache_file(im_path)
        label_path = datasetCache.cache_file(label_path)
        topo_view = read_mnist_images(im_path, dtype='float32')
        # Labels are kept as a column vector of shape (n_examples, 1).
        y = np.atleast_2d(read_mnist_labels(label_path)).T

        if clip_size:
            # Use explicit end indices: a `c:-c` slice would be empty
            # whenever one of the clip components is 0.
            clip_r, clip_c = clip_size[0], clip_size[1]
            n_rows, n_cols = topo_view.shape[1:3]
            topo_view = topo_view[:, clip_r:n_rows - clip_r,
                                  clip_c:n_cols - clip_c]

        if labeler_ai is not None:
            y = np.atleast_2d([labeler_ai(img) for img in topo_view]).T

        if X_aug is not None:
            if Y_aug is None:
                raise ValueError("Y_aug must be provided together with X_aug.")
            topo_view = np.concatenate((X_aug, topo_view))
            y = np.concatenate((Y_aug, y))
            # The augmentation is prepended, so an explicit stop index moves
            # by the number of added examples; an open-ended stop stays open.
            if stop is not None:
                stop = stop + Y_aug.size

        # Locate the positive class (a single digit) and the negative class
        # (one digit or several). Both masks are computed before any label
        # is overwritten, so relabelling cannot affect the selection.
        pos_mask = (y == pos_class_digit)
        neg_mask = np.zeros(y.shape, dtype=bool)
        for digit in np.atleast_1d(neg_class_digit).ravel():
            neg_mask |= (y == digit)
        pos_ids = np.where(pos_mask)[0]
        neg_ids = np.where(neg_mask)[0]
        y[pos_ids] = 1
        y[neg_ids] = 0

        if balance_classes:
            # Equal number of examples per class, interleaved neg/pos.
            ids_size = min(pos_ids.size, neg_ids.size)
            neg_ids = neg_ids[:ids_size]
            pos_ids = pos_ids[:ids_size]
            usable_ids = np.vstack((neg_ids, pos_ids)).reshape((-1), order='F')
        else:
            # Keep every example of either class, in original order, and
            # drop all other digits so only the labels 0 and 1 remain.
            usable_ids = np.where(pos_mask | neg_mask)[0]
        y = y[usable_ids]
        topo_view = topo_view[usable_ids, :, :]

        if binarize:
            topo_view = (topo_view > 0.5).astype('float32')

        max_labels = 2
        m, r, c = topo_view.shape
        topo_view = topo_view.reshape(m, r, c, 1)

        super(MNIST, self).__init__(topo_view=dimshuffle(topo_view), y=y,
                                    axes=axes, y_labels=max_labels)

        assert not np.any(np.isnan(self.X))
        if start is not None and stop is None:
            self.X = self.X[start:, :]
            self.y = self.y[start:]
        elif start is not None:
            assert start >= 0
            if stop > self.X.shape[0]:
                raise ValueError('stop=' + str(stop) + '>' +
                                 'm=' + str(self.X.shape[0]))
            assert stop > start
            self.X = self.X[start:stop, :]
            if self.X.shape[0] != stop - start:
                raise ValueError("X.shape[0]: %d. start: %d stop: %d"
                                 % (self.X.shape[0], start, stop))
            if len(self.y.shape) > 1:
                self.y = self.y[start:stop, :]
            else:
                self.y = self.y[start:stop]
            assert self.y.shape[0] == stop - start

        # Shuffle examples and labels with the same permutation. Index the
        # (possibly sliced) self.y, not the local y, so that labels stay
        # aligned with the rows of self.X.
        reshuffle_ids = np.random.permutation(self.y.shape[0])
        self.y = self.y[reshuffle_ids]
        self.X = self.X[reshuffle_ids]

        if which_set == 'test':
            assert fit_test_preprocessor is None or \
                (fit_preprocessor == fit_test_preprocessor)

        if self.X is not None and preprocessor:
            preprocessor.apply(self, fit_preprocessor)