Exemplo n.º 1
0
def reconstruct(model, nrows, ncols):
    """Plot validation digits and the model's reconstructions of them.

    Two figures are drawn, each an ``nrows`` x ``ncols`` grid of 28x28
    images: the first shows the first ``nrows * ncols`` examples of the
    BinarizedMNIST ``'valid'`` split, the second shows the output of
    ``draw.reconstruct`` on those same examples.

    Parameters
    ----------
    model : object
        Object whose ``top_bricks[0]`` provides a ``reconstruct`` method
        (used as the DRAW brick).
    nrows, ncols : int
        Grid dimensions; ``nrows * ncols`` examples are shown.
    """
    def _show_grid(images):
        # Draw ``images`` (row-major) into a fresh nrows x ncols figure.
        # squeeze=False keeps ``axes`` 2-D even when nrows or ncols is 1;
        # without it, indexing the grid as axes[i][j] breaks in that case.
        _, axes = pyplot.subplots(nrows=nrows, ncols=ncols, squeeze=False)
        # axes.flat walks the grid row by row, matching the example order.
        for ax, image in zip(axes.flat, images):
            ax.axis('off')
            ax.imshow(image.reshape((28, 28)), cmap=cm.Greys_r,
                      interpolation='nearest')

    dataset = BinarizedMNIST('valid')
    originals, = dataset.get_data(request=range(nrows * ncols))
    _show_grid(originals)

    draw = model.top_bricks[0]
    x = tensor.matrix('x')
    x_hat = draw.reconstruct(x)
    # Compile with the updates collected by ComputationGraph so that any
    # shared-variable updates the reconstruction graph declares are applied.
    computation_graph = ComputationGraph([x_hat])
    f = theano.function([x], x_hat, updates=computation_graph.updates)
    _show_grid(f(originals))

    pyplot.show()
Exemplo n.º 2
0
def reconstruct(model, nrows, ncols):
    """Plot validation digits and the model's reconstructions of them.

    Two figures are drawn, each an ``nrows`` x ``ncols`` grid of 28x28
    images: the first shows the first ``nrows * ncols`` examples of the
    BinarizedMNIST ``'valid'`` split, the second shows the output of
    ``draw.reconstruct`` on those same examples.

    Parameters
    ----------
    model : object
        Object whose ``top_bricks[0]`` provides a ``reconstruct`` method
        (used as the DRAW brick).
    nrows, ncols : int
        Grid dimensions; ``nrows * ncols`` examples are shown.
    """
    def _show_grid(images):
        # Draw ``images`` (row-major) into a fresh nrows x ncols figure.
        # squeeze=False keeps ``axes`` 2-D even when nrows or ncols is 1;
        # without it, indexing the grid as axes[i][j] breaks in that case.
        _, axes = pyplot.subplots(nrows=nrows, ncols=ncols, squeeze=False)
        # axes.flat walks the grid row by row, matching the example order.
        for ax, image in zip(axes.flat, images):
            ax.axis('off')
            ax.imshow(image.reshape((28, 28)),
                      cmap=cm.Greys_r,
                      interpolation='nearest')

    dataset = BinarizedMNIST('valid')
    originals, = dataset.get_data(request=range(nrows * ncols))
    _show_grid(originals)

    draw = model.top_bricks[0]
    x = tensor.matrix('x')
    x_hat = draw.reconstruct(x)
    # Compile with the updates collected by ComputationGraph so that any
    # shared-variable updates the reconstruction graph declares are applied.
    computation_graph = ComputationGraph([x_hat])
    f = theano.function([x], x_hat, updates=computation_graph.updates)
    _show_grid(f(originals))

    pyplot.show()
Exemplo n.º 3
0
def test_binarized_mnist_no_split():
    """With no split argument, BinarizedMNIST serves all 70000 examples."""
    skip_if_not_available(datasets=['binarized_mnist'])

    dataset = BinarizedMNIST()
    handle = dataset.open()
    try:
        data = dataset.get_data(handle, slice(0, 70000))[0]
        assert data.shape == (70000, 1, 28, 28)
        assert dataset.num_examples == 70000
    finally:
        # Release the handle even when an assertion above fails.
        dataset.close(handle)
Exemplo n.º 4
0
def test_binarized_mnist_test():
    """The 'test' split holds 10000 examples of shape (1, 28, 28)."""
    skip_if_not_available(datasets=['binarized_mnist'])

    mnist_test = BinarizedMNIST('test')
    handle = mnist_test.open()
    try:
        data = mnist_test.get_data(handle, slice(0, 10000))[0]
        assert data.shape == (10000, 1, 28, 28)
        assert mnist_test.num_examples == 10000
    finally:
        # Release the handle even when an assertion above fails.
        mnist_test.close(handle)
Exemplo n.º 5
0
def load(batch_size, test_batch_size):
    """Build train, validation and test streams over BinarizedMNIST.

    The training split is iterated with a ``ShuffledScheme`` of
    ``batch_size``; the validation split with a ``SequentialScheme`` of
    ``batch_size``; the test split with a ``SequentialScheme`` of
    ``test_batch_size``. Each scheme is sized to its split's
    ``num_examples``.

    Returns
    -------
    tuple
        ``(train, valid, test)``, each the result of ``_make_stream`` on the
        corresponding DataStream and its batch size.
    """
    train_set = BinarizedMNIST(which_sets=('train', ))
    valid_set = BinarizedMNIST(which_sets=('valid', ))
    test_set = BinarizedMNIST(which_sets=('test', ))

    def _stream_over(dataset, scheme_class, size):
        # Wrap ``dataset`` in a DataStream driven by a scheme over all of it.
        scheme = scheme_class(examples=dataset.num_examples, batch_size=size)
        return DataStream(dataset, iteration_scheme=scheme)

    train_stream = _stream_over(train_set, ShuffledScheme, batch_size)
    test_stream = _stream_over(test_set, SequentialScheme, test_batch_size)
    valid_stream = _stream_over(valid_set, SequentialScheme, batch_size)

    return (_make_stream(train_stream, batch_size),
            _make_stream(valid_stream, batch_size),
            _make_stream(test_stream, test_batch_size))
Exemplo n.º 6
0
def test_mnist():
    """Check BinarizedMNIST split sizes, feature layout, errors and pickling."""
    skip_if_not_available(datasets=['binarized_mnist'])

    expected_sizes = (('train', 50000),
                      ('valid', 10000),
                      ('test', 10000))
    splits = {}
    for which_set, size in expected_sizes:
        splits[which_set] = BinarizedMNIST(which_set)
        assert len(splits[which_set].features) == size
        assert splits[which_set].num_examples == size

    # One requested example comes back as a (1, 784) floating-point batch.
    first_feature, = splits['train'].get_data(request=[0])
    assert first_feature.shape == (1, 784)
    assert first_feature.dtype.kind == 'f'

    # An unknown split name must be rejected.
    assert_raises(ValueError, BinarizedMNIST, 'dummy')

    # The dataset must survive a pickle round trip.
    restored = cPickle.loads(cPickle.dumps(splits['test']))
    assert len(restored.features) == 10000

    unflattened = BinarizedMNIST('test', flatten=False)
    assert unflattened.features.shape == (10000, 28, 28)
Exemplo n.º 7
0
def test_binarized_mnist_test():
    """Check dtype, shape, a content checksum and size of the test split."""
    skip_if_not_available(datasets=['binarized_mnist.hdf5'])

    dataset = BinarizedMNIST('test', load_in_memory=False)
    handle = dataset.open()
    try:
        data, = dataset.get_data(handle, slice(0, 10))
        assert data.dtype == 'uint8'
        assert data.shape == (10, 1, 28, 28)
        # MD5 fingerprint of the first 10 test images.
        assert (hashlib.md5(data).hexdigest() ==
                '0fa539ed8cb008880a61be77f744f06a')
        assert dataset.num_examples == 10000
    finally:
        # Release the handle even when an assertion above fails.
        dataset.close(handle)
Exemplo n.º 8
0
def test_binarized_mnist_valid():
    """Check dtype, shape, a content checksum and size of the valid split."""
    skip_if_not_available(datasets=['binarized_mnist.hdf5'])

    dataset = BinarizedMNIST('valid', load_in_memory=False)
    handle = dataset.open()
    try:
        data, = dataset.get_data(handle, slice(0, 10))
        assert data.dtype == 'uint8'
        assert data.shape == (10, 1, 28, 28)
        # MD5 fingerprint of the first 10 validation images.
        assert (hashlib.md5(data).hexdigest() ==
                '65e8099613162b3110a7618037011617')
        assert dataset.num_examples == 10000
    finally:
        # Release the handle even when an assertion above fails.
        dataset.close(handle)
Exemplo n.º 9
0
def test_binarized_mnist_train():
    """Check dtype, shape, a content checksum and size of the train split."""
    skip_if_not_available(datasets=['binarized_mnist.hdf5'])

    dataset = BinarizedMNIST('train', load_in_memory=False)
    handle = dataset.open()
    try:
        data, = dataset.get_data(handle, slice(0, 10))
        assert data.dtype == 'uint8'
        assert data.shape == (10, 1, 28, 28)
        # MD5 fingerprint of the first 10 training images.
        assert (hashlib.md5(data).hexdigest() ==
                '0922fefc9a9d097e3b086b89107fafce')
        assert dataset.num_examples == 50000
    finally:
        # Release the handle even when an assertion above fails.
        dataset.close(handle)
Exemplo n.º 10
0
def test_binarized_mnist_data_path():
    """BinarizedMNIST must point at binarized_mnist.hdf5 under the data path."""
    actual_path = BinarizedMNIST('train').data_path
    expected_path = os.path.join(config.data_path, 'binarized_mnist.hdf5')
    assert actual_path == expected_path
Exemplo n.º 11
0
def test_binarized_mnist_axes():
    """The 'features' source must be labelled batch/channel/height/width."""
    skip_if_not_available(datasets=['binarized_mnist.hdf5'])

    train_set = BinarizedMNIST('train', load_in_memory=False)
    expected_labels = ('batch', 'channel', 'height', 'width')
    assert_equal(train_set.axis_labels['features'], expected_labels)
Exemplo n.º 12
0
def main(nvis, nhid, encoding_lstm_dim, decoding_lstm_dim, T=1):
    """Build a DRAW model and train it on BinarizedMNIST.

    The cost is the negated mean of ``draw.log_likelihood_lower_bound``
    (named ``nll_upper_bound``). Training runs RMSProp for 200 epochs in
    batches of 100 and monitors the cost on the train, valid and test
    splits in batches of 500. It also monitors the mean of every auxiliary
    variable whose name ends in ``term``. The main loop is serialized to
    ``vae.pkl``.

    Parameters
    ----------
    nvis, nhid : int
        Forwarded to ``DRAW`` as ``nvis`` and ``nhid``.
    encoding_lstm_dim, decoding_lstm_dim : int
        ``dim`` of the encoder and decoder LSTMs.
    T : int, optional
        Forwarded to ``DRAW`` as ``T`` (default 1).
    """
    x = tensor.matrix('features')

    # Construct and initialize model
    encoding_mlp = MLP([Tanh()], [None, None])
    decoding_mlp = MLP([Tanh()], [None, None])
    encoding_lstm = LSTM(dim=encoding_lstm_dim)
    decoding_lstm = LSTM(dim=decoding_lstm_dim)
    draw = DRAW(nvis=nvis,
                nhid=nhid,
                T=T,
                encoding_mlp=encoding_mlp,
                decoding_mlp=decoding_mlp,
                encoding_lstm=encoding_lstm,
                decoding_lstm=decoding_lstm,
                biases_init=Constant(0),
                weights_init=Orthogonal())
    draw.push_initialization_config()
    # Override the LSTM weight init after the push (presumably so the
    # pushed Orthogonal config does not overwrite it) and before initialize().
    encoding_lstm.weights_init = IsotropicGaussian(std=0.001)
    decoding_lstm.weights_init = IsotropicGaussian(std=0.001)
    draw.initialize()

    # Compute cost
    cost = -draw.log_likelihood_lower_bound(x).mean()
    cost.name = 'nll_upper_bound'
    model = Model(cost)

    def _float_stream(dataset, batch_size):
        # Sequential batches over the whole split, wrapped in ForceFloatX.
        return ForceFloatX(
            DataStream(dataset=dataset,
                       iteration_scheme=SequentialScheme(dataset.num_examples,
                                                         batch_size)))

    # Datasets and data streams
    mnist_train = BinarizedMNIST('train')
    train_loop_stream = _float_stream(mnist_train, 100)
    train_monitor_stream = _float_stream(mnist_train, 500)
    mnist_valid = BinarizedMNIST('valid')
    valid_monitor_stream = _float_stream(mnist_valid, 500)
    mnist_test = BinarizedMNIST('test')
    test_monitor_stream = _float_stream(mnist_test, 500)

    # Get parameters and monitoring channels
    computation_graph = ComputationGraph([cost])
    params = VariableFilter(roles=[PARAMETER])(computation_graph.variables)
    monitoring_channels = {
        'avg_' + channel.tag.name: channel.mean()
        for channel in VariableFilter(
            name='.*term$')(computation_graph.auxiliary_variables)
    }
    for name, channel in monitoring_channels.items():
        channel.name = name
    # list() is required where dict.values() returns a view (Python 3);
    # a view cannot be concatenated with a list.
    monitored_quantities = list(monitoring_channels.values()) + [cost]

    # Training loop
    step_rule = RMSProp(learning_rate=1e-3, decay_rate=0.95)
    algorithm = GradientDescent(cost=cost, params=params, step_rule=step_rule)
    algorithm.add_updates(computation_graph.updates)
    main_loop = MainLoop(
        model=model,
        data_stream=train_loop_stream,
        algorithm=algorithm,
        extensions=[
            Timing(),
            SerializeMainLoop('vae.pkl', save_separately=['model']),
            FinishAfter(after_n_epochs=200),
            DataStreamMonitoring(monitored_quantities,
                                 train_monitor_stream,
                                 prefix="train",
                                 updates=computation_graph.updates),
            DataStreamMonitoring(monitored_quantities,
                                 valid_monitor_stream,
                                 prefix="valid",
                                 updates=computation_graph.updates),
            DataStreamMonitoring(monitored_quantities,
                                 test_monitor_stream,
                                 prefix="test",
                                 updates=computation_graph.updates),
            ProgressBar(),
            Printing()
        ])
    main_loop.run()
Exemplo n.º 13
0
if __name__ == '__main__':
    # parser = argparse.ArgumentParser()
    # action = parser.add_mutually_exclusive_group()
    # action.add_argument('-t', '--train', help="Start training the model")
    # action.add_argument('-s', '--sample', help='Sample images from the trained model')
    #
    # parser.add_argument('--experiment', nargs=1, type=str,
    #     help="Change default location to run experiment")
    # parser.add_argument('--path', nargs=1, type=str,
    #     help="Change default location to save model")

    if dataset == 'mnist':
        data = MNIST(("train", ), sources=('features', ))
        data_test = MNIST(("test", ), sources=('features', ))
    elif dataset == 'binarized_mnist':
        data = BinarizedMNIST(("train", ), sources=('features', ))
        data_test = BinarizedMNIST(("test", ), sources=('features', ))
    elif dataset == "cifar10":
        data = CIFAR10(("train", ))
        data_test = CIFAR10(("test", ))
    else:
        pass  # Add CIFAR 10
    training_stream = DataStream(data,
                                 iteration_scheme=ShuffledScheme(
                                     data.num_examples, batch_size))
    test_stream = DataStream(data_test,
                             iteration_scheme=ShuffledScheme(
                                 data_test.num_examples, batch_size))
    logger.info("Dataset: {} loaded".format(dataset))

    if train: