示例#1
0
def reconstruct(model, nrows, ncols):
    """Plot validation digits alongside the model's reconstructions of them.

    Builds two figures, each an ``nrows`` x ``ncols`` grid of 28x28 images:
    the first shows the first ``nrows * ncols`` examples of the
    BinarizedMNIST 'valid' split, the second shows the output of
    ``model.top_bricks[0].reconstruct`` on those same examples. Both are
    displayed with ``pyplot.show()``.

    :param model: model whose first top brick provides ``reconstruct``
        (presumably the DRAW brick -- confirm against the caller).
    :param nrows: number of grid rows.
    :param ncols: number of grid columns.
    """
    def show_grid(images):
        # Draw one image per cell of a fresh nrows x ncols figure.
        # squeeze=False keeps ``axes`` 2-D even when nrows or ncols is 1;
        # with the default squeezing, ``axes[i][j]`` fails for
        # single-row or single-column grids.
        _, axes = pyplot.subplots(nrows=nrows, ncols=ncols, squeeze=False)
        # axes.flat walks the grid row by row, i.e. the same (i, j) order
        # as itertools.product(range(nrows), range(ncols)).
        for ax, image in zip(axes.flat, images):
            ax.axis('off')
            ax.imshow(image.reshape((28, 28)),
                      cmap=cm.Greys_r,
                      interpolation='nearest')

    dataset = BinarizedMNIST('valid')
    originals, = dataset.get_data(request=range(nrows * ncols))
    show_grid(originals)

    draw = model.top_bricks[0]
    x = tensor.matrix('x')
    x_hat = draw.reconstruct(x)
    computation_graph = ComputationGraph([x_hat])
    # Forward whatever updates the graph declares (presumably random-state
    # updates for sampling -- confirm) so the compiled function applies them.
    f = theano.function([x], x_hat, updates=computation_graph.updates)
    reconstructions = f(originals)
    show_grid(reconstructions)

    pyplot.show()
示例#2
0
def reconstruct(model, nrows, ncols):
    """Show validation digits and the model's reconstructions side by side.

    Two figures are created, each an ``nrows`` x ``ncols`` grid of 28x28
    images: one with the first ``nrows * ncols`` BinarizedMNIST 'valid'
    examples, and one with ``model.top_bricks[0].reconstruct`` applied to
    those examples. Both are displayed with ``pyplot.show()``.

    :param model: model whose first top brick provides ``reconstruct``
        (presumably the DRAW brick -- confirm against the caller).
    :param nrows: number of grid rows.
    :param ncols: number of grid columns.
    """
    def show_grid(images):
        # squeeze=False guarantees a 2-D ``axes`` array, so single-row or
        # single-column grids work (plain ``axes[i][j]`` broke for them).
        _, axes = pyplot.subplots(nrows=nrows, ncols=ncols, squeeze=False)
        # Row-major traversal, identical to the original (i, j) ordering.
        for ax, image in zip(axes.flat, images):
            ax.axis('off')
            ax.imshow(image.reshape((28, 28)), cmap=cm.Greys_r,
                      interpolation='nearest')

    dataset = BinarizedMNIST('valid')
    originals, = dataset.get_data(request=range(nrows * ncols))
    show_grid(originals)

    draw = model.top_bricks[0]
    x = tensor.matrix('x')
    x_hat = draw.reconstruct(x)
    computation_graph = ComputationGraph([x_hat])
    # Pass along the graph's updates (presumably random-state updates for
    # sampling -- confirm) so they are applied when ``f`` runs.
    f = theano.function([x], x_hat, updates=computation_graph.updates)
    reconstructions = f(originals)
    show_grid(reconstructions)

    pyplot.show()
示例#3
0
def test_binarized_mnist_no_split():
    """Check that BinarizedMNIST() with no split exposes 70000 examples.

    Reads all examples through an explicitly opened handle and checks the
    returned array's shape and ``num_examples``. Guarded by
    ``skip_if_not_available`` for the 'binarized_mnist' dataset.
    """
    skip_if_not_available(datasets=['binarized_mnist'])

    dataset = BinarizedMNIST()
    handle = dataset.open()
    try:
        data = dataset.get_data(handle, slice(0, 70000))[0]
        assert data.shape == (70000, 1, 28, 28)
        assert dataset.num_examples == 70000
    finally:
        # Close even when an assertion fails so the handle is not leaked.
        dataset.close(handle)
示例#4
0
def test_binarized_mnist_test():
    """Check that the 'test' split of BinarizedMNIST has 10000 examples.

    Reads all test examples through an explicitly opened handle and checks
    the returned array's shape and ``num_examples``. Guarded by
    ``skip_if_not_available`` for the 'binarized_mnist' dataset.
    """
    skip_if_not_available(datasets=['binarized_mnist'])

    mnist_test = BinarizedMNIST('test')
    handle = mnist_test.open()
    try:
        data = mnist_test.get_data(handle, slice(0, 10000))[0]
        assert data.shape == (10000, 1, 28, 28)
        assert mnist_test.num_examples == 10000
    finally:
        # Close even when an assertion fails so the handle is not leaked.
        mnist_test.close(handle)
示例#5
0
def test_binarized_mnist_test():
    """Check the first ten 'test' examples against a known MD5 digest.

    Opens the split with ``load_in_memory=False``, reads examples 0-9, and
    checks their dtype, shape and MD5 hash, plus the split's
    ``num_examples``. Guarded by ``skip_if_not_available`` for
    'binarized_mnist.hdf5'.
    """
    skip_if_not_available(datasets=['binarized_mnist.hdf5'])

    dataset = BinarizedMNIST('test', load_in_memory=False)
    handle = dataset.open()
    try:
        data, = dataset.get_data(handle, slice(0, 10))
        assert data.dtype == 'uint8'
        assert data.shape == (10, 1, 28, 28)
        assert hashlib.md5(data).hexdigest() == '0fa539ed8cb008880a61be77f744f06a'
        assert dataset.num_examples == 10000
    finally:
        # Close even when an assertion fails so the handle is not leaked.
        dataset.close(handle)
示例#6
0
def test_binarized_mnist_valid():
    """Check the first ten 'valid' examples against a known MD5 digest.

    Opens the split with ``load_in_memory=False``, reads examples 0-9, and
    checks their dtype, shape and MD5 hash, plus the split's
    ``num_examples``. Guarded by ``skip_if_not_available`` for
    'binarized_mnist.hdf5'.
    """
    skip_if_not_available(datasets=['binarized_mnist.hdf5'])

    dataset = BinarizedMNIST('valid', load_in_memory=False)
    handle = dataset.open()
    try:
        data, = dataset.get_data(handle, slice(0, 10))
        assert data.dtype == 'uint8'
        assert data.shape == (10, 1, 28, 28)
        assert hashlib.md5(data).hexdigest() == '65e8099613162b3110a7618037011617'
        assert dataset.num_examples == 10000
    finally:
        # Close even when an assertion fails so the handle is not leaked.
        dataset.close(handle)
示例#7
0
def test_binarized_mnist_train():
    """Check the first ten 'train' examples against a known MD5 digest.

    Opens the split with ``load_in_memory=False``, reads examples 0-9, and
    checks their dtype, shape and MD5 hash, plus the split's
    ``num_examples``. Guarded by ``skip_if_not_available`` for
    'binarized_mnist.hdf5'.
    """
    skip_if_not_available(datasets=['binarized_mnist.hdf5'])

    dataset = BinarizedMNIST('train', load_in_memory=False)
    handle = dataset.open()
    try:
        data, = dataset.get_data(handle, slice(0, 10))
        assert data.dtype == 'uint8'
        assert data.shape == (10, 1, 28, 28)
        assert hashlib.md5(data).hexdigest() == '0922fefc9a9d097e3b086b89107fafce'
        assert dataset.num_examples == 50000
    finally:
        # Close even when an assertion fails so the handle is not leaked.
        dataset.close(handle)
示例#8
0
def test_mnist():
    """Exercise BinarizedMNIST split sizes, pickling and the flatten option.

    Verifies:
    * the example counts of the 'train', 'valid' and 'test' splits;
    * the shape (1, 784) and floating dtype kind of one training example;
    * that an unknown split name raises ValueError;
    * that a cPickle round-tripped test split still has 10000 features;
    * that ``flatten=False`` yields features of shape (10000, 28, 28).
    """
    skip_if_not_available(datasets=['binarized_mnist'])

    # Splits are built and checked in the same order as listed here.
    split_sizes = [('train', 50000), ('valid', 10000), ('test', 10000)]
    splits = {}
    for which, size in split_sizes:
        split = BinarizedMNIST(which)
        assert len(split.features) == size
        assert split.num_examples == size
        splits[which] = split

    first_feature, = splits['train'].get_data(request=[0])
    assert first_feature.shape == (1, 784)
    assert first_feature.dtype.kind == 'f'

    assert_raises(ValueError, BinarizedMNIST, 'dummy')

    # Serialise and restore the test split to check it survives pickling.
    restored_test = cPickle.loads(cPickle.dumps(splits['test']))
    assert len(restored_test.features) == 10000

    unflattened_test = BinarizedMNIST('test', flatten=False)
    assert unflattened_test.features.shape == (10000, 28, 28)