def get_data(data_name):
    if data_name == 'mnist':
        from fuel.datasets import MNIST

        img_size = (28, 28)

        data_train = MNIST(which_set="train", sources=['features'])
        data_valid = MNIST(which_set="test", sources=['features'])
        data_test = MNIST(which_set="test", sources=['features'])
    elif data_name == 'bmnist':
        from fuel.datasets.binarized_mnist import BinarizedMNIST

        img_size = (28, 28)

        data_train = BinarizedMNIST(which_set='train', sources=['features'])
        data_valid = BinarizedMNIST(which_set='valid', sources=['features'])
        data_test = BinarizedMNIST(which_set='test', sources=['features'])
    elif data_name == 'silhouettes':
        from fuel.datasets.caltech101_silhouettes import CalTech101Silhouettes

        size = 28
        img_size = (size, size)

        data_train = CalTech101Silhouettes(which_set=['train'], size=size, sources=['features'])
        data_valid = CalTech101Silhouettes(which_set=['valid'], size=size, sources=['features'])
        data_test = CalTech101Silhouettes(which_set=['test'], size=size, sources=['features'])
    elif data_name == 'tfd':
        from fuel.datasets.toronto_face_database import TorontoFaceDatabase

        size = 28
        img_size = (size, size)

        data_train = TorontoFaceDatabase(which_set=['unlabeled'], size=size, sources=['features'])
        data_valid = TorontoFaceDatabase(which_set=['valid'], size=size, sources=['features'])
        data_test = TorontoFaceDatabase(which_set=['test'], size=size, sources=['features'])
    elif data_name == 'speech':
        from SynthesisTaskData import SynthesisTaskData

        img_size = (28, 28)

        data_train = SynthesisTaskData(sources=['features'])
        data_valid = SynthesisTaskData(sources=['features'])
        data_test = SynthesisTaskData(sources=['features'])
    else:
        raise ValueError("Unknown dataset %s" % data_name)

    return img_size, data_train, data_valid, data_test
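# Usage sketch (an illustrative addition, not part of the original script):
# wrap one of the datasets above in a Fuel DataStream for minibatch iteration.
# Assumes the chosen dataset has already been downloaded and converted with
# Fuel's tools (e.g. `fuel-download binarized_mnist; fuel-convert binarized_mnist`).
from fuel.schemes import SequentialScheme
from fuel.streams import DataStream

def example_stream(data_name='bmnist', batch_size=100):
    """Build a sequential training stream for one of the datasets above."""
    img_size, data_train, data_valid, data_test = get_data(data_name)
    scheme = SequentialScheme(data_train.num_examples, batch_size)
    return DataStream(data_train, iteration_scheme=scheme)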
def main(name, epochs, batch_size, learning_rate, attention,
         n_iter, enc_dim, dec_dim, z_dim):
    if name is None:
        tag = "watt" if attention else "woatt"
        name = "%s-t%d-enc%d-dec%d-z%d" % (tag, n_iter, enc_dim, dec_dim, z_dim)

    print("\nRunning experiment %s" % name)
    print("     learning rate: %5.3f" % learning_rate)
    print("         attention: %s" % attention)
    print("      n_iterations: %d" % n_iter)
    print(" encoder dimension: %d" % enc_dim)
    print("       z dimension: %d" % z_dim)
    print(" decoder dimension: %d" % dec_dim)
    print()

    #------------------------------------------------------------------------
    x_dim = 28 * 28
    img_height, img_width = (28, 28)

    rnninits = {
        'weights_init': Orthogonal(),
        #'weights_init': IsotropicGaussian(0.001),
        'biases_init': Constant(0.),
    }
    inits = {
        'weights_init': Orthogonal(),
        #'weights_init': IsotropicGaussian(0.01),
        'biases_init': Constant(0.),
    }

    prior_mu = T.zeros([z_dim])
    prior_log_sigma = T.zeros([z_dim])

    if attention:
        read_N = 4
        write_N = 6
        read_dim = 2 * read_N**2

        reader = AttentionReader(x_dim=x_dim, dec_dim=dec_dim,
                                 width=img_width, height=img_height,
                                 N=read_N, **inits)
        writer = AttentionWriter(input_dim=dec_dim, output_dim=x_dim,
                                 width=img_width, height=img_height,
                                 N=write_N, **inits)
    else:
        read_dim = 2 * x_dim

        reader = Reader(x_dim=x_dim, dec_dim=dec_dim, **inits)
        writer = Writer(input_dim=dec_dim, output_dim=x_dim, **inits)

    encoder = LSTM(dim=enc_dim, name="RNN_enc", **rnninits)
    decoder = LSTM(dim=dec_dim, name="RNN_dec", **rnninits)
    encoder_mlp = MLP([Tanh()], [(read_dim + dec_dim), 4 * enc_dim],
                      name="MLP_enc", **inits)
    decoder_mlp = MLP([Tanh()], [z_dim, 4 * dec_dim],
                      name="MLP_dec", **inits)
    q_sampler = Qsampler(input_dim=enc_dim, output_dim=z_dim, **inits)

    for brick in [reader, writer, encoder, decoder,
                  encoder_mlp, decoder_mlp, q_sampler]:
        brick.allocate()
        brick.initialize()

    #------------------------------------------------------------------------
    x = tensor.matrix('features')

    # This is one DRAW iteration: read from the error image, update the
    # encoder, sample z, update the decoder, and additively write to the
    # canvas c.
    def one_iteration(c, h_enc, c_enc, z_mean, z_log_sigma, z, h_dec, c_dec, x):
        x_hat = x - T.nnet.sigmoid(c)
        r = reader.apply(x, x_hat, h_dec)
        i_enc = encoder_mlp.apply(T.concatenate([r, h_dec], axis=1))
        h_enc, c_enc = encoder.apply(states=h_enc, cells=c_enc,
                                     inputs=i_enc, iterate=False)
        z_mean, z_log_sigma, z = q_sampler.apply(h_enc)
        i_dec = decoder_mlp.apply(z)
        h_dec, c_dec = decoder.apply(states=h_dec, cells=c_dec,
                                     inputs=i_dec, iterate=False)
        c = c + writer.apply(h_dec)
        return c, h_enc, c_enc, z_mean, z_log_sigma, z, h_dec, c_dec

    outputs_info = [
        T.zeros([batch_size, x_dim]),     # c
        T.zeros([batch_size, enc_dim]),   # h_enc
        T.zeros([batch_size, enc_dim]),   # c_enc
        T.zeros([batch_size, z_dim]),     # z_mean
        T.zeros([batch_size, z_dim]),     # z_log_sigma
        T.zeros([batch_size, z_dim]),     # z
        T.zeros([batch_size, dec_dim]),   # h_dec
        T.zeros([batch_size, dec_dim]),   # c_dec
    ]

    outputs, scan_updates = theano.scan(fn=one_iteration,
                                        sequences=[],
                                        outputs_info=outputs_info,
                                        non_sequences=[x],
                                        n_steps=n_iter)

    c, h_enc, c_enc, z_mean, z_log_sigma, z, h_dec, c_dec = outputs

    # Closed-form KL(q(z|x) || p(z)) between diagonal Gaussians, summed over
    # the latent dimensions; one term per DRAW iteration.
    kl_terms = (
        prior_log_sigma - z_log_sigma
        + 0.5 * (tensor.exp(2 * z_log_sigma) + (z_mean - prior_mu)**2)
        / tensor.exp(2 * prior_log_sigma)
        - 0.5
    ).sum(axis=-1)

    x_recons = T.nnet.sigmoid(c[-1, :, :])
    recons_term = BinaryCrossEntropy().apply(x, x_recons)
    recons_term.name = "recons_term"

    cost = recons_term + kl_terms.sum(axis=0).mean()
    cost.name = "nll_bound"

    #------------------------------------------------------------
    cg = ComputationGraph([cost])
    params = VariableFilter(roles=[PARAMETER])(cg.variables)

    algorithm = GradientDescent(
        cost=cost,
        params=params,
        step_rule=CompositeRule([
            #StepClipping(3.),
            Adam(learning_rate),
        ])
        #step_rule=RMSProp(learning_rate),
        #step_rule=Momentum(learning_rate=learning_rate, momentum=0.95)
    )
    algorithm.add_updates(scan_updates)

    #------------------------------------------------------------------------
    # Setup monitors
    monitors = [cost]
    for t in range(n_iter):
        kl_term_t = kl_terms[t, :].mean()
        kl_term_t.name = "kl_term_%d" % t

        x_recons_t = T.nnet.sigmoid(c[t, :, :])
        recons_term_t = BinaryCrossEntropy().apply(x, x_recons_t)
        recons_term_t = recons_term_t.mean()
        recons_term_t.name = "recons_term_%d" % t

        monitors += [kl_term_t, recons_term_t]

    train_monitors = monitors[:]
    train_monitors += [aggregation.mean(algorithm.total_gradient_norm)]
    train_monitors += [aggregation.mean(algorithm.total_step_norm)]

    # Live plotting...
    plot_channels = [
        ["train_nll_bound", "test_nll_bound"],
        ["train_kl_term_%d" % t for t in range(n_iter)],
        ["train_recons_term_%d" % t for t in range(n_iter)],
        ["train_total_gradient_norm", "train_total_step_norm"],
    ]

    #------------------------------------------------------------
    mnist_train = BinarizedMNIST("train", sources=['features'])
    mnist_test = BinarizedMNIST("test", sources=['features'])
    #mnist_train = MNIST("train", binary=True, sources=['features'])
    #mnist_test = MNIST("test", binary=True, sources=['features'])

    main_loop = MainLoop(
        model=None,
        data_stream=ForceFloatX(
            DataStream(mnist_train,
                       iteration_scheme=SequentialScheme(
                           mnist_train.num_examples, batch_size))),
        algorithm=algorithm,
        extensions=[
            Timing(),
            FinishAfter(after_n_epochs=epochs),
            DataStreamMonitoring(
                monitors,
                ForceFloatX(
                    DataStream(mnist_test,
                               iteration_scheme=SequentialScheme(
                                   mnist_test.num_examples, batch_size))),
                updates=scan_updates,
                prefix="test"),
            TrainingDataMonitoring(
                train_monitors,
                prefix="train",
                after_every_epoch=True),
            SerializeMainLoop(name + ".pkl"),
            Plot(name, channels=plot_channels),
            ProgressBar(),
            Printing(),
        ])
    main_loop.run()
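# Entry-point sketch (illustrative; the original script's argument parsing is
# not shown here, so every flag name and default below is an assumption).
if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--name", type=str, default=None)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--batch-size", type=int, dest="batch_size", default=100)
    parser.add_argument("--learning-rate", type=float, dest="learning_rate", default=1e-3)
    parser.add_argument("--attention", action="store_true", default=False)
    parser.add_argument("--niter", type=int, dest="n_iter", default=10)
    parser.add_argument("--enc-dim", type=int, dest="enc_dim", default=256)
    parser.add_argument("--dec-dim", type=int, dest="dec_dim", default=256)
    parser.add_argument("--z-dim", type=int, dest="z_dim", default=100)
    args = parser.parse_args()

    main(**vars(args))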
def main(name, epochs, batch_size, learning_rate, attention,
         n_iter, enc_dim, dec_dim, z_dim):
    x_dim = 28*28
    img_height, img_width = (28, 28)

    rnninits = {
        #'weights_init': Orthogonal(),
        'weights_init': IsotropicGaussian(0.01),
        'biases_init': Constant(0.),
    }
    inits = {
        #'weights_init': Orthogonal(),
        'weights_init': IsotropicGaussian(0.01),
        'biases_init': Constant(0.),
    }

    if attention != "":
        read_N, write_N = attention.split(',')
        read_N = int(read_N)
        write_N = int(write_N)
        read_dim = 2*read_N**2

        reader = AttentionReader(x_dim=x_dim, dec_dim=dec_dim,
                                 width=img_width, height=img_height,
                                 N=read_N, **inits)
        writer = AttentionWriter(input_dim=dec_dim, output_dim=x_dim,
                                 width=img_width, height=img_height,
                                 N=write_N, **inits)
        attention_tag = "r%d-w%d" % (read_N, write_N)
    else:
        read_dim = 2*x_dim

        reader = Reader(x_dim=x_dim, dec_dim=dec_dim, **inits)
        writer = Writer(input_dim=dec_dim, output_dim=x_dim, **inits)

        attention_tag = "full"

    #----------------------------------------------------------------------
    # Learning rate
    def lr_tag(value):
        """ Convert a float into a short tag-usable string representation.
            E.g.:
                0.1   -> 11
                0.01  -> 12
                0.001 -> 13
                0.005 -> 53
        """
        exp = np.floor(np.log10(value))
        leading = ("%e" % value)[0]
        return "%s%d" % (leading, -exp)

    lr_str = lr_tag(learning_rate)

    name = "%s-%s-t%d-enc%d-dec%d-z%d-lr%s" % \
        (name, attention_tag, n_iter, enc_dim, dec_dim, z_dim, lr_str)

    print("\nRunning experiment %s" % name)
    print("     learning rate: %5.3f" % learning_rate)
    print("         attention: %s" % attention)
    print("      n_iterations: %d" % n_iter)
    print(" encoder dimension: %d" % enc_dim)
    print("       z dimension: %d" % z_dim)
    print(" decoder dimension: %d" % dec_dim)
    print()

    #----------------------------------------------------------------------
    encoder_rnn = LSTM(dim=enc_dim, name="RNN_enc", **rnninits)
    decoder_rnn = LSTM(dim=dec_dim, name="RNN_dec", **rnninits)
    encoder_mlp = MLP([Identity()], [(read_dim+dec_dim), 4*enc_dim],
                      name="MLP_enc", **inits)
    decoder_mlp = MLP([Identity()], [z_dim, 4*dec_dim],
                      name="MLP_dec", **inits)
    q_sampler = Qsampler(input_dim=enc_dim, output_dim=z_dim, **inits)

    draw = DrawModel(
        n_iter,
        reader=reader,
        encoder_mlp=encoder_mlp,
        encoder_rnn=encoder_rnn,
        sampler=q_sampler,
        decoder_mlp=decoder_mlp,
        decoder_rnn=decoder_rnn,
        writer=writer)
    draw.initialize()

    #------------------------------------------------------------------------
    x = tensor.matrix('features')

    #x_recons = 1. + x
    x_recons, kl_terms = draw.reconstruct(x)
    #x_recons, _, _, _, _ = draw.silly(x, n_steps=10, batch_size=100)
    #x_recons = x_recons[-1,:,:]

    #samples = draw.sample(100)
    #x_recons = samples[-1, :, :]

    recons_term = BinaryCrossEntropy().apply(x, x_recons)
    recons_term.name = "recons_term"

    cost = recons_term + kl_terms.sum(axis=0).mean()
    cost.name = "nll_bound"

    #------------------------------------------------------------
    cg = ComputationGraph([cost])
    params = VariableFilter(roles=[PARAMETER])(cg.variables)

    algorithm = GradientDescent(
        cost=cost,
        params=params,
        step_rule=CompositeRule([
            StepClipping(10.),
            Adam(learning_rate),
        ])
        #step_rule=RMSProp(learning_rate),
        #step_rule=Momentum(learning_rate=learning_rate, momentum=0.95)
    )
    #algorithm.add_updates(scan_updates)

    #------------------------------------------------------------------------
    # Setup monitors
    monitors = [cost]
    for t in range(n_iter):
        kl_term_t = kl_terms[t, :].mean()
        kl_term_t.name = "kl_term_%d" % t

        #x_recons_t = T.nnet.sigmoid(c[t,:,:])
        #recons_term_t = BinaryCrossEntropy().apply(x, x_recons_t)
        #recons_term_t = recons_term_t.mean()
        #recons_term_t.name = "recons_term_%d" % t

        monitors += [kl_term_t]

    train_monitors = monitors[:]
    train_monitors += [aggregation.mean(algorithm.total_gradient_norm)]
    train_monitors += [aggregation.mean(algorithm.total_step_norm)]

    # Live plotting...
    plot_channels = [
        ["train_nll_bound", "test_nll_bound"],
        ["train_kl_term_%d" % t for t in range(n_iter)],
        #["train_recons_term_%d" % t for t in range(n_iter)],
        ["train_total_gradient_norm", "train_total_step_norm"],
    ]

    #------------------------------------------------------------
    mnist_train = BinarizedMNIST("train", sources=['features'])
    mnist_valid = BinarizedMNIST("valid", sources=['features'])
    mnist_test = BinarizedMNIST("test", sources=['features'])

    train_stream = DataStream(mnist_train,
                              iteration_scheme=SequentialScheme(
                                  mnist_train.num_examples, batch_size))
    valid_stream = DataStream(mnist_valid,
                              iteration_scheme=SequentialScheme(
                                  mnist_valid.num_examples, batch_size))
    test_stream = DataStream(mnist_test,
                             iteration_scheme=SequentialScheme(
                                 mnist_test.num_examples, batch_size))

    main_loop = MainLoop(
        model=Model(cost),
        data_stream=train_stream,
        algorithm=algorithm,
        extensions=[
            Timing(),
            FinishAfter(after_n_epochs=epochs),
            TrainingDataMonitoring(
                train_monitors,
                prefix="train",
                after_epoch=True),
            #DataStreamMonitoring(
            #    monitors,
            #    valid_stream,
            ##    updates=scan_updates,
            #    prefix="valid"),
            DataStreamMonitoring(
                monitors,
                test_stream,
                #updates=scan_updates,
                prefix="test"),
            Checkpoint(name+".pkl", after_epoch=True,
                       save_separately=['log', 'model']),
            #Dump(name),
            Plot(name, channels=plot_channels),
            ProgressBar(),
            Printing(),
        ])
    main_loop.run()
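# Sampling sketch (illustrative). The commented-out `draw.sample(100)` above
# suggests DrawModel exposes a `sample` method that returns one canvas per
# DRAW step; under that assumption, the final canvases can be compiled into a
# standalone Theano function for generating images after training.
import theano

def compile_sampler(draw, n_samples=100):
    samples = draw.sample(n_samples)         # assumed shape: (n_iter, n_samples, x_dim)
    return theano.function([], samples[-1])  # keep only the final canvas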
def get_data(data_name):
    if data_name == 'bmnist':
        from fuel.datasets.binarized_mnist import BinarizedMNIST

        x_dim = 28 * 28

        data_train = BinarizedMNIST(which_sets=['train'], sources=['features'])
        data_valid = BinarizedMNIST(which_sets=['valid'], sources=['features'])
        data_test = BinarizedMNIST(which_sets=['test'], sources=['features'])
    elif data_name == 'mnist':
        from fuel.datasets.mnist import MNIST

        x_dim = 28 * 28

        data_train = MNIST(which_sets=['train'], sources=['features'])
        data_valid = MNIST(which_sets=['test'], sources=['features'])
        data_test = MNIST(which_sets=['test'], sources=['features'])
    elif data_name == 'silhouettes':
        from fuel.datasets.caltech101_silhouettes import CalTech101Silhouettes

        size = 28
        x_dim = size * size

        data_train = CalTech101Silhouettes(which_sets=['train'], size=size, sources=['features'])
        data_valid = CalTech101Silhouettes(which_sets=['valid'], size=size, sources=['features'])
        data_test = CalTech101Silhouettes(which_sets=['test'], size=size, sources=['features'])
    elif data_name == 'tfd':
        from fuel.datasets.toronto_face_database import TorontoFaceDatabase

        size = 48
        x_dim = size * size

        data_train = TorontoFaceDatabase(which_sets=['unlabeled'], size=size, sources=['features'])
        data_valid = TorontoFaceDatabase(which_sets=['valid'], size=size, sources=['features'])
        data_test = TorontoFaceDatabase(which_sets=['test'], size=size, sources=['features'])
    elif data_name == 'bars':
        from bars_data import Bars

        width = 4
        x_dim = width * width

        data_train = Bars(num_examples=5000, width=width, sources=['features'])
        data_valid = Bars(num_examples=5000, width=width, sources=['features'])
        data_test = Bars(num_examples=5000, width=width, sources=['features'])
    elif data_name in local_datasets:
        from fuel.datasets.hdf5 import H5PYDataset

        fname = "data/" + data_name + ".hdf5"

        data_train = H5PYDataset(fname, which_sets=["train"], sources=['features'], load_in_memory=True)
        data_valid = H5PYDataset(fname, which_sets=["valid"], sources=['features'], load_in_memory=True)
        data_test = H5PYDataset(fname, which_sets=["test"], sources=['features'], load_in_memory=True)

        # Infer the flattened input dimensionality from the first 100 examples.
        some_features = data_train.get_data(None, slice(0, 100))[0]
        assert some_features.shape[0] == 100

        some_features = some_features.reshape([100, -1])
        x_dim = some_features.shape[1]
    else:
        raise ValueError("Unknown dataset %s" % data_name)

    return x_dim, data_train, data_valid, data_test
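# Supporting definition sketch (illustrative). `local_datasets` above must be
# defined at module level before get_data() is called; its definition is not
# part of this excerpt. One plausible minimal version lists the HDF5 files
# available under data/:
import os

local_datasets = ([f[:-len(".hdf5")] for f in os.listdir("data")
                   if f.endswith(".hdf5")]
                  if os.path.isdir("data") else [])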
def main(name, epochs, batch_size, learning_rate):
    if name is None:
        name = "att-rw"

    print("\nRunning experiment %s" % name)
    print(" learning rate: %5.3f" % learning_rate)
    print()

    #------------------------------------------------------------------------
    img_height, img_width = 28, 28

    read_N = 12
    write_N = 14

    inits = {
        #'weights_init': Orthogonal(),
        'weights_init': IsotropicGaussian(0.001),
        'biases_init': Constant(0.),
    }

    x_dim = img_height * img_width

    reader = ZoomableAttentionWindow(img_height, img_width, read_N)
    writer = ZoomableAttentionWindow(img_height, img_width, write_N)

    # Parameterize the attention reader and writer
    mlpr = MLP(activations=[Tanh(), Identity()],
               dims=[x_dim, 50, 5],
               name="RMLP", **inits)
    mlpw = MLP(activations=[Tanh(), Identity()],
               dims=[x_dim, 50, 5],
               name="WMLP", **inits)

    # MLP between the reader and writer
    mlp = MLP(activations=[Tanh(), Identity()],
              dims=[read_N**2, 300, write_N**2],
              name="MLP", **inits)

    for brick in [mlpr, mlpw, mlp]:
        brick.allocate()
        brick.initialize()

    #------------------------------------------------------------------------
    x = tensor.matrix('features')

    hr = mlpr.apply(x)
    hw = mlpw.apply(x)

    center_y, center_x, delta, sigma, gamma = reader.nn2att(hr)
    r = reader.read(x, center_y, center_x, delta, sigma)

    h = mlp.apply(r)

    center_y, center_x, delta, sigma, gamma = writer.nn2att(hw)
    c = writer.write(h, center_y, center_x, delta, sigma) / gamma
    x_recons = T.nnet.sigmoid(c)

    cost = BinaryCrossEntropy().apply(x, x_recons)
    cost.name = "cost"

    #------------------------------------------------------------
    cg = ComputationGraph([cost])
    params = VariableFilter(roles=[PARAMETER])(cg.variables)

    algorithm = GradientDescent(
        cost=cost,
        params=params,
        step_rule=CompositeRule([
            RemoveNotFinite(),
            Adam(learning_rate),
            StepClipping(3.),
        ])
        #step_rule=RMSProp(learning_rate),
        #step_rule=Momentum(learning_rate=learning_rate, momentum=0.95)
    )

    #------------------------------------------------------------------------
    # Setup monitors
    monitors = [cost]
    #for v in [center_y, center_x, log_delta, log_sigma, log_gamma]:
    #    v_mean = v.mean()
    #    v_mean.name = v.name
    #    monitors += [v_mean]
    #    monitors += [aggregation.mean(v)]

    train_monitors = monitors[:]
    train_monitors += [aggregation.mean(algorithm.total_gradient_norm)]
    train_monitors += [aggregation.mean(algorithm.total_step_norm)]

    # Live plotting...
    plot_channels = [
        ["cost"],
    ]

    #------------------------------------------------------------
    mnist_train = BinarizedMNIST("train", sources=['features'])
    mnist_test = BinarizedMNIST("test", sources=['features'])
    #mnist_train = MNIST("train", binary=True, sources=['features'])
    #mnist_test = MNIST("test", binary=True, sources=['features'])

    main_loop = MainLoop(
        model=Model(cost),
        data_stream=ForceFloatX(
            DataStream(mnist_train,
                       iteration_scheme=SequentialScheme(
                           mnist_train.num_examples, batch_size))),
        algorithm=algorithm,
        extensions=[
            Timing(),
            FinishAfter(after_n_epochs=epochs),
            DataStreamMonitoring(
                monitors,
                ForceFloatX(
                    DataStream(mnist_test,
                               iteration_scheme=SequentialScheme(
                                   mnist_test.num_examples, batch_size))),
                prefix="test"),
            TrainingDataMonitoring(
                train_monitors,
                prefix="train",
                after_every_epoch=True),
            SerializeMainLoop(name + ".pkl"),
            #Plot(name, channels=plot_channels),
            ProgressBar(),
            Printing(),
        ])
    main_loop.run()
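# Filterbank sketch (illustrative, following the DRAW paper's attention
# equations rather than this repo's exact implementation). The
# ZoomableAttentionWindow above reads an N x N grid of Gaussian filter
# responses from the image; the 1-D filter matrix along one image axis,
# given the center, stride (delta), and width (sigma) produced by nn2att,
# can be built like this:
import numpy as np

def filterbank_1d(img_dim, N, center, delta, sigma):
    """Return an (N, img_dim) matrix of normalized Gaussian filters."""
    grid = np.arange(N) - N / 2.0 + 0.5              # filter offsets on the grid
    mu = center + delta * grid                       # filter centers in image coords
    pos = np.arange(img_dim)
    F = np.exp(-(pos[None, :] - mu[:, None]) ** 2 / (2.0 * sigma ** 2))
    F /= np.maximum(F.sum(axis=1, keepdims=True), 1e-8)  # each filter sums to 1
    return F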
def get_data(data_name):
    if data_name == 'mnist':
        from fuel.datasets import MNIST

        img_size = (28, 28)
        channels = 1

        data_train = MNIST(which_sets=["train"], sources=['features'])
        data_valid = MNIST(which_sets=["test"], sources=['features'])
        data_test = MNIST(which_sets=["test"], sources=['features'])
    elif data_name == 'bmnist':
        from fuel.datasets.binarized_mnist import BinarizedMNIST

        img_size = (28, 28)
        channels = 1

        data_train = BinarizedMNIST(which_sets=['train'], sources=['features'])
        data_valid = BinarizedMNIST(which_sets=['valid'], sources=['features'])
        data_test = BinarizedMNIST(which_sets=['test'], sources=['features'])
    # TODO: make a generic catch-all for loading custom datasets like "colormnist"
    elif data_name == 'colormnist':
        from draw.colormnist import ColorMNIST

        img_size = (28, 28)
        channels = 3

        data_train = ColorMNIST(which_sets=['train'], sources=['features'])
        data_valid = ColorMNIST(which_sets=['test'], sources=['features'])
        data_test = ColorMNIST(which_sets=['test'], sources=['features'])
    elif data_name == 'cifar10':
        from fuel.datasets.cifar10 import CIFAR10

        img_size = (32, 32)
        channels = 3

        data_train = CIFAR10(which_sets=['train'], sources=['features'])
        data_valid = CIFAR10(which_sets=['test'], sources=['features'])
        data_test = CIFAR10(which_sets=['test'], sources=['features'])
    elif data_name == 'svhn2':
        from fuel.datasets.svhn import SVHN

        img_size = (32, 32)
        channels = 3

        data_train = SVHN(which_format=2, which_sets=['train'], sources=['features'])
        data_valid = SVHN(which_format=2, which_sets=['test'], sources=['features'])
        data_test = SVHN(which_format=2, which_sets=['test'], sources=['features'])
    elif data_name == 'silhouettes':
        from fuel.datasets.caltech101_silhouettes import CalTech101Silhouettes

        size = 28
        img_size = (size, size)
        channels = 1

        data_train = CalTech101Silhouettes(which_sets=['train'], size=size, sources=['features'])
        data_valid = CalTech101Silhouettes(which_sets=['valid'], size=size, sources=['features'])
        data_test = CalTech101Silhouettes(which_sets=['test'], size=size, sources=['features'])
    elif data_name == 'tfd':
        from fuel.datasets.toronto_face_database import TorontoFaceDatabase

        size = 28
        img_size = (size, size)
        channels = 1

        data_train = TorontoFaceDatabase(which_sets=['unlabeled'], size=size, sources=['features'])
        data_valid = TorontoFaceDatabase(which_sets=['valid'], size=size, sources=['features'])
        data_test = TorontoFaceDatabase(which_sets=['test'], size=size, sources=['features'])
    else:
        raise ValueError("Unknown dataset %s" % data_name)

    return img_size, channels, data_train, data_valid, data_test
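# Usage sketch (illustrative): the img_size/channels pair returned above is
# enough to un-flatten a feature batch for visualization.
import numpy as np

def batch_to_images(features, img_size, channels):
    """Reshape a (batch, x_dim) array into (batch, channels, height, width)."""
    height, width = img_size
    return np.asarray(features).reshape(-1, channels, height, width)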