def reconstruct(model, nrows, ncols):
    # Plot the original validation digits.
    dataset = BinarizedMNIST('valid')
    originals, = dataset.get_data(request=range(nrows * ncols))
    figure, axes = pyplot.subplots(nrows=nrows, ncols=ncols)
    for n, (i, j) in enumerate(itertools.product(xrange(nrows),
                                                 xrange(ncols))):
        ax = axes[i][j]
        ax.axis('off')
        ax.imshow(originals[n].reshape((28, 28)), cmap=cm.Greys_r,
                  interpolation='nearest')

    # Compile the reconstruction function and plot the reconstructions.
    draw = model.top_bricks[0]
    x = tensor.matrix('x')
    x_hat = draw.reconstruct(x)
    computation_graph = ComputationGraph([x_hat])
    f = theano.function([x], x_hat, updates=computation_graph.updates)
    reconstructions = f(originals)
    figure, axes = pyplot.subplots(nrows=nrows, ncols=ncols)
    for n, (i, j) in enumerate(itertools.product(xrange(nrows),
                                                 xrange(ncols))):
        ax = axes[i][j]
        ax.axis('off')
        ax.imshow(reconstructions[n].reshape((28, 28)), cmap=cm.Greys_r,
                  interpolation='nearest')
    pyplot.show()
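A minimal sketch of driving reconstruct() on a previously trained model. The 'vae.pkl' filename and the main loop's model attribute are assumptions borrowed from the training script further down; adjust them to however the model was actually serialized.

# Hypothetical usage of reconstruct(); the checkpoint name and the
# main_loop.model attribute are assumptions, not part of the original.
import cPickle

with open('vae.pkl', 'rb') as f:
    main_loop = cPickle.load(f)
reconstruct(main_loop.model, nrows=4, ncols=4)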
def test_binarized_mnist_no_split():
    skip_if_not_available(datasets=['binarized_mnist'])
    dataset = BinarizedMNIST()
    handle = dataset.open()
    data = dataset.get_data(handle, slice(0, 70000))[0]
    assert data.shape == (70000, 1, 28, 28)
    assert dataset.num_examples == 70000
    dataset.close(handle)
def test_binarized_mnist_test():
    skip_if_not_available(datasets=['binarized_mnist'])
    mnist_test = BinarizedMNIST('test')
    handle = mnist_test.open()
    data = mnist_test.get_data(handle, slice(0, 10000))[0]
    assert data.shape == (10000, 1, 28, 28)
    assert mnist_test.num_examples == 10000
    mnist_test.close(handle)
def load(batch_size, test_batch_size):
    tr_data = BinarizedMNIST(which_sets=('train',))
    val_data = BinarizedMNIST(which_sets=('valid',))
    test_data = BinarizedMNIST(which_sets=('test',))

    ntrain = tr_data.num_examples
    nval = val_data.num_examples
    ntest = test_data.num_examples

    tr_scheme = ShuffledScheme(examples=ntrain, batch_size=batch_size)
    tr_stream = DataStream(tr_data, iteration_scheme=tr_scheme)

    te_scheme = SequentialScheme(examples=ntest, batch_size=test_batch_size)
    te_stream = DataStream(test_data, iteration_scheme=te_scheme)

    val_scheme = SequentialScheme(examples=nval, batch_size=batch_size)
    val_stream = DataStream(val_data, iteration_scheme=val_scheme)

    return (_make_stream(tr_stream, batch_size),
            _make_stream(val_stream, batch_size),
            _make_stream(te_stream, test_batch_size))
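_make_stream is referenced above but not defined in this excerpt. Below is a minimal sketch of what such a helper could look like, assuming it merely adapts a Fuel DataStream into a zero-argument callable that yields one epoch of feature batches; the real helper may do something different.

def _make_stream(stream, batch_size):
    # Hypothetical stand-in for the undefined _make_stream helper.
    # batch_size is accepted only to match the call sites above; the batch
    # size is already fixed by the stream's iteration scheme.
    def epoch_iterator():
        for batch in stream.get_epoch_iterator(as_dict=True):
            yield batch['features']
    return epoch_iterator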
def test_mnist():
    skip_if_not_available(datasets=['binarized_mnist'])
    mnist_train = BinarizedMNIST('train')
    assert len(mnist_train.features) == 50000
    assert mnist_train.num_examples == 50000
    mnist_valid = BinarizedMNIST('valid')
    assert len(mnist_valid.features) == 10000
    assert mnist_valid.num_examples == 10000
    mnist_test = BinarizedMNIST('test')
    assert len(mnist_test.features) == 10000
    assert mnist_test.num_examples == 10000

    first_feature, = mnist_train.get_data(request=[0])
    assert first_feature.shape == (1, 784)
    assert first_feature.dtype.kind == 'f'

    assert_raises(ValueError, BinarizedMNIST, 'dummy')

    # The dataset must survive a pickling round-trip.
    mnist_test = cPickle.loads(cPickle.dumps(mnist_test))
    assert len(mnist_test.features) == 10000

    mnist_test_unflattened = BinarizedMNIST('test', flatten=False)
    assert mnist_test_unflattened.features.shape == (10000, 28, 28)
def test_binarized_mnist_test():
    skip_if_not_available(datasets=['binarized_mnist.hdf5'])
    dataset = BinarizedMNIST('test', load_in_memory=False)
    handle = dataset.open()
    data, = dataset.get_data(handle, slice(0, 10))
    assert data.dtype == 'uint8'
    assert data.shape == (10, 1, 28, 28)
    assert hashlib.md5(data).hexdigest() == '0fa539ed8cb008880a61be77f744f06a'
    assert dataset.num_examples == 10000
    dataset.close(handle)
def test_binarized_mnist_valid():
    skip_if_not_available(datasets=['binarized_mnist.hdf5'])
    dataset = BinarizedMNIST('valid', load_in_memory=False)
    handle = dataset.open()
    data, = dataset.get_data(handle, slice(0, 10))
    assert data.dtype == 'uint8'
    assert data.shape == (10, 1, 28, 28)
    assert hashlib.md5(data).hexdigest() == '65e8099613162b3110a7618037011617'
    assert dataset.num_examples == 10000
    dataset.close(handle)
def test_binarized_mnist_train():
    skip_if_not_available(datasets=['binarized_mnist.hdf5'])
    dataset = BinarizedMNIST('train', load_in_memory=False)
    handle = dataset.open()
    data, = dataset.get_data(handle, slice(0, 10))
    assert data.dtype == 'uint8'
    assert data.shape == (10, 1, 28, 28)
    assert hashlib.md5(data).hexdigest() == '0922fefc9a9d097e3b086b89107fafce'
    assert dataset.num_examples == 50000
    dataset.close(handle)
def test_binarized_mnist_data_path():
    assert BinarizedMNIST('train').data_path == os.path.join(
        config.data_path, 'binarized_mnist.hdf5')
def test_binarized_mnist_axes():
    skip_if_not_available(datasets=['binarized_mnist.hdf5'])
    dataset = BinarizedMNIST('train', load_in_memory=False)
    assert_equal(dataset.axis_labels['features'],
                 ('batch', 'channel', 'height', 'width'))
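The tests above assume imports along the following lines. This is a sketch based on Fuel's conventional layout; the exact module paths (in particular the home of skip_if_not_available) may differ between versions.

import cPickle
import hashlib
import os

from numpy.testing import assert_equal, assert_raises

from fuel import config
from fuel.datasets import BinarizedMNIST
from tests import skip_if_not_available  # Fuel test utility (assumed path)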
def main(nvis, nhid, encoding_lstm_dim, decoding_lstm_dim, T=1):
    x = tensor.matrix('features')

    # Construct and initialize model
    encoding_mlp = MLP([Tanh()], [None, None])
    decoding_mlp = MLP([Tanh()], [None, None])
    encoding_lstm = LSTM(dim=encoding_lstm_dim)
    decoding_lstm = LSTM(dim=decoding_lstm_dim)
    draw = DRAW(nvis=nvis, nhid=nhid, T=T, encoding_mlp=encoding_mlp,
                decoding_mlp=decoding_mlp, encoding_lstm=encoding_lstm,
                decoding_lstm=decoding_lstm, biases_init=Constant(0),
                weights_init=Orthogonal())
    draw.push_initialization_config()
    encoding_lstm.weights_init = IsotropicGaussian(std=0.001)
    decoding_lstm.weights_init = IsotropicGaussian(std=0.001)
    draw.initialize()

    # Compute cost: the negative of the log-likelihood lower bound is an
    # upper bound on the negative log-likelihood.
    cost = -draw.log_likelihood_lower_bound(x).mean()
    cost.name = 'nll_upper_bound'
    model = Model(cost)

    # Datasets and data streams
    mnist_train = BinarizedMNIST('train')
    train_loop_stream = ForceFloatX(DataStream(
        dataset=mnist_train,
        iteration_scheme=SequentialScheme(mnist_train.num_examples, 100)))
    train_monitor_stream = ForceFloatX(DataStream(
        dataset=mnist_train,
        iteration_scheme=SequentialScheme(mnist_train.num_examples, 500)))
    mnist_valid = BinarizedMNIST('valid')
    valid_monitor_stream = ForceFloatX(DataStream(
        dataset=mnist_valid,
        iteration_scheme=SequentialScheme(mnist_valid.num_examples, 500)))
    mnist_test = BinarizedMNIST('test')
    test_monitor_stream = ForceFloatX(DataStream(
        dataset=mnist_test,
        iteration_scheme=SequentialScheme(mnist_test.num_examples, 500)))

    # Get parameters and monitoring channels
    computation_graph = ComputationGraph([cost])
    params = VariableFilter(roles=[PARAMETER])(computation_graph.variables)

    monitoring_channels = dict([
        ('avg_' + channel.tag.name, channel.mean()) for channel in
        VariableFilter(name='.*term$')(computation_graph.auxiliary_variables)])
    for name, channel in monitoring_channels.items():
        channel.name = name
    monitored_quantities = monitoring_channels.values() + [cost]

    # Training loop
    step_rule = RMSProp(learning_rate=1e-3, decay_rate=0.95)
    algorithm = GradientDescent(cost=cost, params=params, step_rule=step_rule)
    algorithm.add_updates(computation_graph.updates)
    main_loop = MainLoop(
        model=model, data_stream=train_loop_stream, algorithm=algorithm,
        extensions=[
            Timing(),
            SerializeMainLoop('vae.pkl', save_separately=['model']),
            FinishAfter(after_n_epochs=200),
            DataStreamMonitoring(monitored_quantities, train_monitor_stream,
                                 prefix="train",
                                 updates=computation_graph.updates),
            DataStreamMonitoring(monitored_quantities, valid_monitor_stream,
                                 prefix="valid",
                                 updates=computation_graph.updates),
            DataStreamMonitoring(monitored_quantities, test_monitor_stream,
                                 prefix="test",
                                 updates=computation_graph.updates),
            ProgressBar(),
            Printing()])
    main_loop.run()
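An illustrative call to main(). Only nvis=784 is dictated by the flattened 28x28 binarized MNIST digits the function consumes; the hidden size, LSTM dimensions, and number of glimpses T below are arbitrary example values, not ones taken from the original.

# Example invocation; all sizes except nvis=784 are illustrative choices.
main(nvis=784, nhid=100, encoding_lstm_dim=256, decoding_lstm_dim=256, T=16)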
if __name__ == '__main__':
    # parser = argparse.ArgumentParser()
    # action = parser.add_mutually_exclusive_group()
    # action.add_argument('-t', '--train', help="Start training the model")
    # action.add_argument('-s', '--sample',
    #                     help='Sample images from the trained model')
    # parser.add_argument('--experiment', nargs=1, type=str,
    #                     help="Change default location to run experiment")
    # parser.add_argument('--path', nargs=1, type=str,
    #                     help="Change default location to save model")

    if dataset == 'mnist':
        data = MNIST(("train",), sources=('features',))
        data_test = MNIST(("test",), sources=('features',))
    elif dataset == 'binarized_mnist':
        data = BinarizedMNIST(("train",), sources=('features',))
        data_test = BinarizedMNIST(("test",), sources=('features',))
    elif dataset == "cifar10":
        data = CIFAR10(("train",))
        data_test = CIFAR10(("test",))
    else:
        # Fail loudly instead of silently leaving `data` undefined.
        raise ValueError("Unknown dataset: {}".format(dataset))

    training_stream = DataStream(
        data, iteration_scheme=ShuffledScheme(data.num_examples, batch_size))
    test_stream = DataStream(
        data_test,
        iteration_scheme=ShuffledScheme(data_test.num_examples, batch_size))
    logger.info("Dataset: {} loaded".format(dataset))

    if train: