def _plot_digit_grid(images, nrows, ncols):
    """Draw ``images`` on a new nrows x ncols figure, one 28x28 greyscale panel each.

    Images are assigned to panels in row-major order and axis decorations
    are hidden. Returns the created figure.
    """
    # squeeze=False keeps `axes` a 2-D array even when nrows or ncols is 1;
    # without it, subplots returns a 1-D array (or a bare Axes) and 2-D
    # indexing breaks.
    figure, axes = pyplot.subplots(nrows=nrows, ncols=ncols, squeeze=False)
    for ax, image in zip(axes.flat, images):
        ax.axis('off')
        ax.imshow(image.reshape((28, 28)), cmap=cm.Greys_r,
                  interpolation='nearest')
    return figure


def reconstruct(model, nrows, ncols):
    """Show validation digits next to the model's reconstructions of them.

    Takes the first ``nrows * ncols`` examples of the BinarizedMNIST 'valid'
    split and plots them on one figure. It then runs them through
    ``model.top_bricks[0].reconstruct`` and plots the results on a second
    figure with the same layout. Finally it calls ``pyplot.show()``.

    Parameters
    ----------
    model : object
        Object exposing ``top_bricks``; its first brick must provide a
        ``reconstruct`` method that accepts a symbolic matrix.
    nrows, ncols : int
        Grid dimensions; ``nrows * ncols`` examples are shown.
    """
    dataset = BinarizedMNIST('valid')
    originals, = dataset.get_data(request=range(nrows * ncols))
    _plot_digit_grid(originals, nrows, ncols)

    draw = model.top_bricks[0]
    # The compiled function takes a matrix, so `originals` is assumed to
    # hold flattened (n, 784) rows -- TODO confirm against the dataset.
    x = tensor.matrix('x')
    x_hat = draw.reconstruct(x)
    computation_graph = ComputationGraph([x_hat])
    # Forward whatever updates the graph collected (if any) so theano
    # applies them when the function is called.
    f = theano.function([x], x_hat, updates=computation_graph.updates)
    reconstructions = f(originals)
    _plot_digit_grid(reconstructions, nrows, ncols)

    pyplot.show()
def test_binarized_mnist_no_split():
    """BinarizedMNIST with no split argument exposes 70000 (1, 28, 28) examples."""
    skip_if_not_available(datasets=['binarized_mnist'])
    dataset = BinarizedMNIST()
    handle = dataset.open()
    try:
        data = dataset.get_data(handle, slice(0, 70000))[0]
        assert data.shape == (70000, 1, 28, 28)
        assert dataset.num_examples == 70000
    finally:
        # Release the handle even when an assertion above fails.
        dataset.close(handle)
def test_binarized_mnist_test_full_split():
    """The BinarizedMNIST 'test' split exposes 10000 (1, 28, 28) examples."""
    # Named distinctly: another function in this module is called
    # test_binarized_mnist_test, and sharing that name would shadow this
    # test so pytest never collects it.
    skip_if_not_available(datasets=['binarized_mnist'])
    mnist_test = BinarizedMNIST('test')
    handle = mnist_test.open()
    try:
        data = mnist_test.get_data(handle, slice(0, 10000))[0]
        assert data.shape == (10000, 1, 28, 28)
        assert mnist_test.num_examples == 10000
    finally:
        # Release the handle even when an assertion above fails.
        mnist_test.close(handle)
def test_binarized_mnist_test():
    """Check dtype, shape, content hash and size of the BinarizedMNIST 'test' split.

    Reads the first 10 examples with ``load_in_memory=False`` and compares
    their MD5 digest against a fixed reference value.
    """
    skip_if_not_available(datasets=['binarized_mnist.hdf5'])
    dataset = BinarizedMNIST('test', load_in_memory=False)
    handle = dataset.open()
    try:
        data, = dataset.get_data(handle, slice(0, 10))
        assert data.dtype == 'uint8'
        assert data.shape == (10, 1, 28, 28)
        assert hashlib.md5(data).hexdigest() == '0fa539ed8cb008880a61be77f744f06a'
        assert dataset.num_examples == 10000
    finally:
        # Release the handle even when an assertion above fails.
        dataset.close(handle)
def test_binarized_mnist_valid():
    """Check dtype, shape, content hash and size of the BinarizedMNIST 'valid' split.

    Reads the first 10 examples with ``load_in_memory=False`` and compares
    their MD5 digest against a fixed reference value.
    """
    skip_if_not_available(datasets=['binarized_mnist.hdf5'])
    dataset = BinarizedMNIST('valid', load_in_memory=False)
    handle = dataset.open()
    try:
        data, = dataset.get_data(handle, slice(0, 10))
        assert data.dtype == 'uint8'
        assert data.shape == (10, 1, 28, 28)
        assert hashlib.md5(data).hexdigest() == '65e8099613162b3110a7618037011617'
        assert dataset.num_examples == 10000
    finally:
        # Release the handle even when an assertion above fails.
        dataset.close(handle)
def test_binarized_mnist_train():
    """Check dtype, shape, content hash and size of the BinarizedMNIST 'train' split.

    Reads the first 10 examples with ``load_in_memory=False`` and compares
    their MD5 digest against a fixed reference value.
    """
    skip_if_not_available(datasets=['binarized_mnist.hdf5'])
    dataset = BinarizedMNIST('train', load_in_memory=False)
    handle = dataset.open()
    try:
        data, = dataset.get_data(handle, slice(0, 10))
        assert data.dtype == 'uint8'
        assert data.shape == (10, 1, 28, 28)
        assert hashlib.md5(data).hexdigest() == '0922fefc9a9d097e3b086b89107fafce'
        assert dataset.num_examples == 50000
    finally:
        # Release the handle even when an assertion above fails.
        dataset.close(handle)
def test_mnist():
    """Exercise BinarizedMNIST: split sizes, feature format, errors, pickling, flatten."""
    skip_if_not_available(datasets=['binarized_mnist'])

    # Each split is built and checked in turn, in train/valid/test order.
    expected_sizes = (('train', 50000), ('valid', 10000), ('test', 10000))
    splits = {}
    for split_name, size in expected_sizes:
        split = BinarizedMNIST(split_name)
        assert len(split.features) == size
        assert split.num_examples == size
        splits[split_name] = split

    first_feature, = splits['train'].get_data(request=[0])
    assert first_feature.shape == (1, 784)
    assert first_feature.dtype.kind == 'f'

    # An unrecognised split name must be rejected.
    assert_raises(ValueError, BinarizedMNIST, 'dummy')

    # A pickle round-trip must preserve the loaded features.
    restored_test = cPickle.loads(cPickle.dumps(splits['test']))
    assert len(restored_test.features) == 10000

    unflattened_test = BinarizedMNIST('test', flatten=False)
    assert unflattened_test.features.shape == (10000, 28, 28)