Example #1
def get_data(data_name):
    if data_name == 'mnist':
        from fuel.datasets import MNIST

        img_size = (28, 28)

        data_train = MNIST(which_sets=["train"], sources=['features'])
        data_valid = MNIST(which_sets=["test"], sources=['features'])
        data_test = MNIST(which_sets=["test"], sources=['features'])
    elif data_name == 'bmnist':
        from fuel.datasets.binarized_mnist import BinarizedMNIST

        img_size = (28, 28)

        data_train = BinarizedMNIST(which_sets=['train'], sources=['features'])
        data_valid = BinarizedMNIST(which_sets=['valid'], sources=['features'])
        data_test = BinarizedMNIST(which_sets=['test'], sources=['features'])
    elif data_name == 'silhouettes':
        from fuel.datasets.caltech101_silhouettes import CalTech101Silhouettes

        size = 28
        img_size = (size, size)

        data_train = CalTech101Silhouettes(which_sets=['train'],
                                           size=size,
                                           sources=['features'])
        data_valid = CalTech101Silhouettes(which_sets=['valid'],
                                           size=size,
                                           sources=['features'])
        data_test = CalTech101Silhouettes(which_sets=['test'],
                                          size=size,
                                          sources=['features'])
    elif data_name == 'tfd':
        from fuel.datasets.toronto_face_database import TorontoFaceDatabase

        size = 28
        img_size = (size, size)

        data_train = TorontoFaceDatabase(which_sets=['unlabeled'],
                                         size=size,
                                         sources=['features'])
        data_valid = TorontoFaceDatabase(which_sets=['valid'],
                                         size=size,
                                         sources=['features'])
        data_test = TorontoFaceDatabase(which_sets=['test'],
                                        size=size,
                                        sources=['features'])

    elif data_name == 'speech':
        from SynthesisTaskData import SynthesisTaskData

        img_size = (28, 28)

        # Custom dataset; the same split is reused for train/valid/test.
        data_train = SynthesisTaskData(sources=['features'])
        data_valid = SynthesisTaskData(sources=['features'])
        data_test = SynthesisTaskData(sources=['features'])
    else:
        raise ValueError("Unknown dataset %s" % data_name)

    return img_size, data_train, data_valid, data_test
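
A minimal usage sketch for the loader above, wrapping the training split in a Fuel DataStream exactly as the training loops in the later examples do (the batch size of 100 is arbitrary):

from fuel.streams import DataStream
from fuel.schemes import SequentialScheme

img_size, data_train, data_valid, data_test = get_data('bmnist')

# Iterate over the training split in fixed-size minibatches.
train_stream = DataStream(
    data_train,
    iteration_scheme=SequentialScheme(data_train.num_examples, 100))

for (features,) in train_stream.get_epoch_iterator():
    print(features.shape)  # e.g. (100, 1, 28, 28) for binarized MNIST
    break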
Example #2
def main(name, epochs, batch_size, learning_rate, attention, n_iter, enc_dim,
         dec_dim, z_dim):

    if name is None:
        tag = "watt" if attention else "woatt"
        name = "%s-t%d-enc%d-dec%d-z%d" % (tag, n_iter, enc_dim, dec_dim,
                                           z_dim)

    print("\nRunning experiment %s" % name)
    print("         learning rate: %5.3f" % learning_rate)
    print("             attention: %s" % attention)
    print("          n_iterations: %d" % n_iter)
    print("     encoder dimension: %d" % enc_dim)
    print("           z dimension: %d" % z_dim)
    print("     decoder dimension: %d" % dec_dim)
    print()

    #------------------------------------------------------------------------

    x_dim = 28 * 28
    img_height, img_width = (28, 28)

    rnninits = {
        'weights_init': Orthogonal(),
        #'weights_init': IsotropicGaussian(0.001),
        'biases_init': Constant(0.),
    }
    inits = {
        'weights_init': Orthogonal(),
        #'weights_init': IsotropicGaussian(0.01),
        'biases_init': Constant(0.),
    }

    prior_mu = T.zeros([z_dim])
    prior_log_sigma = T.zeros([z_dim])

    if attention:
        read_N = 4
        write_N = 6
        read_dim = 2 * read_N**2

        reader = AttentionReader(x_dim=x_dim,
                                 dec_dim=dec_dim,
                                 width=img_width,
                                 height=img_height,
                                 N=read_N,
                                 **inits)
        writer = AttentionWriter(input_dim=dec_dim,
                                 output_dim=x_dim,
                                 width=img_width,
                                 height=img_height,
                                 N=write_N,
                                 **inits)
    else:
        read_dim = 2 * x_dim

        reader = Reader(x_dim=x_dim, dec_dim=dec_dim, **inits)
        writer = Writer(input_dim=dec_dim, output_dim=x_dim, **inits)

    encoder = LSTM(dim=enc_dim, name="RNN_enc", **rnninits)
    decoder = LSTM(dim=dec_dim, name="RNN_dec", **rnninits)
    encoder_mlp = MLP([Tanh()], [(read_dim + dec_dim), 4 * enc_dim],
                      name="MLP_enc",
                      **inits)
    decoder_mlp = MLP([Tanh()], [z_dim, 4 * dec_dim], name="MLP_dec", **inits)
    q_sampler = Qsampler(input_dim=enc_dim, output_dim=z_dim, **inits)

    for brick in [
            reader, writer, encoder, decoder, encoder_mlp, decoder_mlp,
            q_sampler
    ]:
        brick.allocate()
        brick.initialize()

    #------------------------------------------------------------------------
    x = tensor.matrix('features')

    # One DRAW iteration: read (with attention), encode, sample z, decode,
    # write. scan passes every outputs_info entry, hence the unused z_* args.
    def one_iteration(c, h_enc, c_enc, z_mean, z_log_sigma, z, h_dec, c_dec,
                      x):
        x_hat = x - T.nnet.sigmoid(c)
        r = reader.apply(x, x_hat, h_dec)
        i_enc = encoder_mlp.apply(T.concatenate([r, h_dec], axis=1))
        h_enc, c_enc = encoder.apply(states=h_enc,
                                     cells=c_enc,
                                     inputs=i_enc,
                                     iterate=False)
        z_mean, z_log_sigma, z = q_sampler.apply(h_enc)
        i_dec = decoder_mlp.apply(z)
        h_dec, c_dec = decoder.apply(states=h_dec,
                                     cells=c_dec,
                                     inputs=i_dec,
                                     iterate=False)
        c = c + writer.apply(h_dec)
        return c, h_enc, c_enc, z_mean, z_log_sigma, z, h_dec, c_dec

    outputs_info = [
        T.zeros([batch_size, x_dim]),  # c
        T.zeros([batch_size, enc_dim]),  # h_enc
        T.zeros([batch_size, enc_dim]),  # c_enc
        T.zeros([batch_size, z_dim]),  # z_mean
        T.zeros([batch_size, z_dim]),  # z_log_sigma
        T.zeros([batch_size, z_dim]),  # z
        T.zeros([batch_size, dec_dim]),  # h_dec
        T.zeros([batch_size, dec_dim]),  # c_dec
    ]

    outputs, scan_updates = theano.scan(fn=one_iteration,
                                        sequences=[],
                                        outputs_info=outputs_info,
                                        non_sequences=[x],
                                        n_steps=n_iter)

    c, h_enc, c_enc, z_mean, z_log_sigma, z, h_dec, c_dec = outputs

    kl_terms = (prior_log_sigma - z_log_sigma + 0.5 *
                (tensor.exp(2 * z_log_sigma) +
                 (z_mean - prior_mu)**2) / tensor.exp(2 * prior_log_sigma) -
                0.5).sum(axis=-1)

    x_recons = T.nnet.sigmoid(c[-1, :, :])
    recons_term = BinaryCrossEntropy().apply(x, x_recons)
    recons_term.name = "recons_term"

    cost = recons_term + kl_terms.sum(axis=0).mean()
    cost.name = "nll_bound"

    #------------------------------------------------------------
    cg = ComputationGraph([cost])
    params = VariableFilter(roles=[PARAMETER])(cg.variables)

    algorithm = GradientDescent(
        cost=cost,
        params=params,
        step_rule=CompositeRule([
            #StepClipping(3.),
            Adam(learning_rate),
        ])
        #step_rule=RMSProp(learning_rate),
        #step_rule=Momentum(learning_rate=learning_rate, momentum=0.95)
    )
    algorithm.add_updates(scan_updates)

    #------------------------------------------------------------------------
    # Setup monitors
    monitors = [cost]
    for t in range(n_iter):
        kl_term_t = kl_terms[t, :].mean()
        kl_term_t.name = "kl_term_%d" % t

        x_recons_t = T.nnet.sigmoid(c[t, :, :])
        recons_term_t = BinaryCrossEntropy().apply(x, x_recons_t)
        recons_term_t = recons_term_t.mean()
        recons_term_t.name = "recons_term_%d" % t

        monitors += [kl_term_t, recons_term_t]

    train_monitors = monitors[:]
    train_monitors += [aggregation.mean(algorithm.total_gradient_norm)]
    train_monitors += [aggregation.mean(algorithm.total_step_norm)]

    # Live plotting...
    plot_channels = [["train_nll_bound", "test_nll_bound"],
                     ["train_kl_term_%d" % t for t in range(n_iter)],
                     ["train_recons_term_%d" % t for t in range(n_iter)],
                     ["train_total_gradient_norm", "train_total_step_norm"]]

    #------------------------------------------------------------

    mnist_train = BinarizedMNIST("train", sources=['features'])
    mnist_test = BinarizedMNIST("test", sources=['features'])
    #mnist_train = MNIST("train", binary=True, sources=['features'])
    #mnist_test = MNIST("test", binary=True, sources=['features'])

    main_loop = MainLoop(
        model=None,
        data_stream=ForceFloatX(
            DataStream(mnist_train,
                       iteration_scheme=SequentialScheme(
                           mnist_train.num_examples, batch_size))),
        algorithm=algorithm,
        extensions=[
            Timing(),
            FinishAfter(after_n_epochs=epochs),
            DataStreamMonitoring(
                monitors,
                ForceFloatX(
                    DataStream(mnist_test,
                               iteration_scheme=SequentialScheme(
                                   mnist_test.num_examples, batch_size))),
                updates=scan_updates,
                prefix="test"),
            TrainingDataMonitoring(train_monitors,
                                   prefix="train",
                                   after_every_epoch=True),
            SerializeMainLoop(name + ".pkl"),
            Plot(name, channels=plot_channels),
            ProgressBar(),
            Printing()
        ])
    main_loop.run()
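
Qsampler, Reader, Writer and the attention bricks come from the DRAW repository and are not shown here; like the other training scripts on this page, the snippet also assumes module-level imports from theano, blocks, and fuel. In the DRAW paper the latent z is drawn with the standard reparameterization trick, which the q_sampler.apply call above presumably implements. A minimal numpy sketch of that step, assuming linear readouts of the encoder state (W_mean and W_log_sigma are hypothetical parameter matrices):

import numpy as np

def q_sample(h_enc, W_mean, W_log_sigma, rng=np.random):
    # Linear readouts of the encoder state (an assumed parameterization;
    # the real Qsampler brick is defined elsewhere in the DRAW codebase).
    z_mean = h_enc.dot(W_mean)
    z_log_sigma = h_enc.dot(W_log_sigma)
    # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
    # so gradients can flow through mu and sigma.
    eps = rng.standard_normal(z_mean.shape)
    z = z_mean + np.exp(z_log_sigma) * eps
    return z_mean, z_log_sigma, z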
Example #3
def main(name, epochs, batch_size, learning_rate, 
         attention, n_iter, enc_dim, dec_dim, z_dim):

    x_dim = 28*28
    img_height, img_width = (28, 28)
    
    rnninits = {
        #'weights_init': Orthogonal(),
        'weights_init': IsotropicGaussian(0.01),
        'biases_init': Constant(0.),
    }
    inits = {
        #'weights_init': Orthogonal(),
        'weights_init': IsotropicGaussian(0.01),
        'biases_init': Constant(0.),
    }
    
    if attention != "":
        read_N, write_N = attention.split(',')
    
        read_N = int(read_N)
        write_N = int(write_N)
        read_dim = 2*read_N**2

        reader = AttentionReader(x_dim=x_dim, dec_dim=dec_dim,
                                 width=img_width, height=img_height,
                                 N=read_N, **inits)
        writer = AttentionWriter(input_dim=dec_dim, output_dim=x_dim,
                                 width=img_width, height=img_height,
                                 N=write_N, **inits)
        attention_tag = "r%d-w%d" % (read_N, write_N)
    else:
        read_dim = 2*x_dim

        reader = Reader(x_dim=x_dim, dec_dim=dec_dim, **inits)
        writer = Writer(input_dim=dec_dim, output_dim=x_dim, **inits)

        attention_tag = "full"

    #----------------------------------------------------------------------

    # Learning rate
    def lr_tag(value):
        """ Convert a float into a short tag-usable string representation. E.g.:
            0.1   -> 11
            0.01  -> 12
            0.001 -> 13
            0.005 -> 53
        """
        exp = np.floor(np.log10(value))
        leading = ("%e"%value)[0]
        return "%s%d" % (leading, -exp)

    lr_str = lr_tag(learning_rate)
    name = "%s-%s-t%d-enc%d-dec%d-z%d-lr%s" % (name, attention_tag, n_iter, enc_dim, dec_dim, z_dim, lr_str)

    print("\nRunning experiment %s" % name)
    print("         learning rate: %5.3f" % learning_rate) 
    print("             attention: %s" % attention)
    print("          n_iterations: %d" % n_iter)
    print("     encoder dimension: %d" % enc_dim)
    print("           z dimension: %d" % z_dim)
    print("     decoder dimension: %d" % dec_dim)
    print()

    #----------------------------------------------------------------------

    encoder_rnn = LSTM(dim=enc_dim, name="RNN_enc", **rnninits)
    decoder_rnn = LSTM(dim=dec_dim, name="RNN_dec", **rnninits)
    encoder_mlp = MLP([Identity()], [(read_dim + dec_dim), 4 * enc_dim],
                      name="MLP_enc", **inits)
    decoder_mlp = MLP([Identity()], [z_dim, 4 * dec_dim],
                      name="MLP_dec", **inits)
    q_sampler = Qsampler(input_dim=enc_dim, output_dim=z_dim, **inits)

    draw = DrawModel(
                n_iter, 
                reader=reader,
                encoder_mlp=encoder_mlp,
                encoder_rnn=encoder_rnn,
                sampler=q_sampler,
                decoder_mlp=decoder_mlp,
                decoder_rnn=decoder_rnn,
                writer=writer)
    draw.initialize()


    #------------------------------------------------------------------------
    x = tensor.matrix('features')
    
    #x_recons = 1. + x
    x_recons, kl_terms = draw.reconstruct(x)
    #x_recons, _, _, _, _ = draw.silly(x, n_steps=10, batch_size=100)
    #x_recons = x_recons[-1,:,:]

    #samples = draw.sample(100)
    #x_recons = samples[-1, :, :]

    recons_term = BinaryCrossEntropy().apply(x, x_recons)
    recons_term.name = "recons_term"

    cost = recons_term + kl_terms.sum(axis=0).mean()
    cost.name = "nll_bound"

    #------------------------------------------------------------
    cg = ComputationGraph([cost])
    params = VariableFilter(roles=[PARAMETER])(cg.variables)

    algorithm = GradientDescent(
        cost=cost, 
        params=params,
        step_rule=CompositeRule([
            StepClipping(10.), 
            Adam(learning_rate),
        ])
        #step_rule=RMSProp(learning_rate),
        #step_rule=Momentum(learning_rate=learning_rate, momentum=0.95)
    )
    #algorithm.add_updates(scan_updates)


    #------------------------------------------------------------------------
    # Setup monitors
    monitors = [cost]
    for t in range(n_iter):
        kl_term_t = kl_terms[t,:].mean()
        kl_term_t.name = "kl_term_%d" % t

        #x_recons_t = T.nnet.sigmoid(c[t,:,:])
        #recons_term_t = BinaryCrossEntropy().apply(x, x_recons_t)
        #recons_term_t = recons_term_t.mean()
        #recons_term_t.name = "recons_term_%d" % t

        monitors += [kl_term_t]

    train_monitors = monitors[:]
    train_monitors += [aggregation.mean(algorithm.total_gradient_norm)]
    train_monitors += [aggregation.mean(algorithm.total_step_norm)]
    # Live plotting...
    plot_channels = [
        ["train_nll_bound", "test_nll_bound"],
        ["train_kl_term_%d" % t for t in range(n_iter)],
        #["train_recons_term_%d" % t for t in range(n_iter)],
        ["train_total_gradient_norm", "train_total_step_norm"]
    ]

    #------------------------------------------------------------

    mnist_train = BinarizedMNIST("train", sources=['features'])
    mnist_valid = BinarizedMNIST("valid", sources=['features'])
    mnist_test = BinarizedMNIST("test", sources=['features'])

    train_stream = DataStream(mnist_train, iteration_scheme=SequentialScheme(mnist_train.num_examples, batch_size))
    valid_stream = DataStream(mnist_valid, iteration_scheme=SequentialScheme(mnist_valid.num_examples, batch_size))
    test_stream  = DataStream(mnist_test,  iteration_scheme=SequentialScheme(mnist_test.num_examples, batch_size))

    main_loop = MainLoop(
        model=Model(cost),
        data_stream=train_stream,
        algorithm=algorithm,
        extensions=[
            Timing(),
            FinishAfter(after_n_epochs=epochs),
            TrainingDataMonitoring(
                train_monitors, 
                prefix="train",
                after_epoch=True),
#            DataStreamMonitoring(
#                monitors,
#                valid_stream,
##                updates=scan_updates, 
#                prefix="valid"),
            DataStreamMonitoring(
                monitors,
                test_stream,
#                updates=scan_updates, 
                prefix="test"),
            Checkpoint(name+".pkl", after_epoch=True, save_separately=['log', 'model']),
            #Dump(name),
            Plot(name, channels=plot_channels),
            ProgressBar(),
            Printing()])
    main_loop.run()
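
The kl_terms returned by draw.reconstruct correspond to the closed-form KL divergence between the diagonal Gaussian q(z|x) and the unit prior, the same expression Example #2 spells out symbolically. A small numpy check of that formula against a Monte Carlo estimate:

import numpy as np

mu, log_sigma = 0.5, np.log(0.8)

# Closed form: KL(N(mu, sigma^2) || N(0, 1)), per latent dimension.
kl_closed = -log_sigma + 0.5 * (np.exp(2 * log_sigma) + mu**2) - 0.5

# Monte Carlo estimate of E_q[log q(z) - log p(z)].
z = mu + np.exp(log_sigma) * np.random.standard_normal(1000000)
log_q = (-0.5 * ((z - mu) / np.exp(log_sigma))**2
         - log_sigma - 0.5 * np.log(2 * np.pi))
log_p = -0.5 * z**2 - 0.5 * np.log(2 * np.pi)
kl_mc = np.mean(log_q - log_p)

print(kl_closed, kl_mc)   # the two agree to about three decimal places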
Example #4
def get_data(data_name):
    if data_name == 'bmnist':
        from fuel.datasets.binarized_mnist import BinarizedMNIST

        x_dim = 28 * 28

        data_train = BinarizedMNIST(which_sets=['train'], sources=['features'])
        data_valid = BinarizedMNIST(which_sets=['valid'], sources=['features'])
        data_test = BinarizedMNIST(which_sets=['test'], sources=['features'])
    elif data_name == 'mnist':
        from fuel.datasets.mnist import MNIST

        x_dim = 28 * 28

        data_train = MNIST(which_sets=['train'], sources=['features'])
        data_valid = MNIST(which_sets=['test'], sources=['features'])
        data_test = MNIST(which_sets=['test'], sources=['features'])
    elif data_name == 'silhouettes':
        from fuel.datasets.caltech101_silhouettes import CalTech101Silhouettes

        size = 28
        x_dim = size * size

        data_train = CalTech101Silhouettes(which_sets=['train'],
                                           size=size,
                                           sources=['features'])
        data_valid = CalTech101Silhouettes(which_sets=['valid'],
                                           size=size,
                                           sources=['features'])
        data_test = CalTech101Silhouettes(which_sets=['test'],
                                          size=size,
                                          sources=['features'])
    elif data_name == 'tfd':
        from fuel.datasets.toronto_face_database import TorontoFaceDatabase

        size = 48
        x_dim = size * size

        data_train = TorontoFaceDatabase(which_sets=['unlabeled'],
                                         size=size,
                                         sources=['features'])
        data_valid = TorontoFaceDatabase(which_sets=['valid'],
                                         size=size,
                                         sources=['features'])
        data_test = TorontoFaceDatabase(which_sets=['test'],
                                        size=size,
                                        sources=['features'])
    elif data_name == 'bars':
        from bars_data import Bars

        width = 4
        x_dim = width * width

        data_train = Bars(num_examples=5000, width=width, sources=['features'])
        data_valid = Bars(num_examples=5000, width=width, sources=['features'])
        data_test = Bars(num_examples=5000, width=width, sources=['features'])
    elif data_name in local_datasets:
        from fuel.datasets.hdf5 import H5PYDataset

        fname = "data/" + data_name + ".hdf5"

        data_train = H5PYDataset(fname,
                                 which_sets=["train"],
                                 sources=['features'],
                                 load_in_memory=True)
        data_valid = H5PYDataset(fname,
                                 which_sets=["valid"],
                                 sources=['features'],
                                 load_in_memory=True)
        data_test = H5PYDataset(fname,
                                which_sets=["test"],
                                sources=['features'],
                                load_in_memory=True)

        # Infer x_dim from the first 100 examples, flattened.
        some_features = data_train.get_data(None, slice(0, 100))[0]
        assert some_features.shape[0] == 100

        some_features = some_features.reshape([100, -1])
        x_dim = some_features.shape[1]
    else:
        raise ValueError("Unknown dataset %s" % data_name)

    return x_dim, data_train, data_valid, data_test
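
local_datasets is a module-level name that this snippet never defines. A plausible sketch of it, assuming it simply lists the names with a matching file under data/ (hypothetical; the original project's definition is not shown):

import glob
import os

# Hypothetical: any name with a matching data/<name>.hdf5 file counts
# as a local dataset for the H5PYDataset branch above.
local_datasets = [os.path.splitext(os.path.basename(path))[0]
                  for path in glob.glob("data/*.hdf5")]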
Example #5
def get_data(data_name):
    if data_name == 'bmnist':
        from fuel.datasets.binarized_mnist import BinarizedMNIST

        x_dim = 28*28 

        data_train = BinarizedMNIST(which_sets=['train'], sources=['features'])
        data_valid = BinarizedMNIST(which_sets=['valid'], sources=['features'])
        data_test  = BinarizedMNIST(which_sets=['test'], sources=['features'])
    elif data_name == 'mnist':
        from fuel.datasets.mnist import MNIST

        x_dim = 28*28 

        data_train = MNIST(which_sets=['train'], sources=['features'])
        data_valid = MNIST(which_sets=['test'], sources=['features'])
        data_test  = MNIST(which_sets=['test'], sources=['features'])
    elif data_name == 'silhouettes':
        from fuel.datasets.caltech101_silhouettes import CalTech101Silhouettes

        size = 28
        x_dim = size*size

        data_train = CalTech101Silhouettes(which_sets=['train'], size=size, sources=['features'])
        data_valid = CalTech101Silhouettes(which_sets=['valid'], size=size, sources=['features'])
        data_test  = CalTech101Silhouettes(which_sets=['test'], size=size, sources=['features'])
    elif data_name == 'tfd':
        from fuel.datasets.toronto_face_database import TorontoFaceDatabase

        size = 48
        x_dim = size*size

        data_train = TorontoFaceDatabase(which_sets=['unlabeled'], size=size, sources=['features'])
        data_valid = TorontoFaceDatabase(which_sets=['valid'], size=size, sources=['features'])
        data_test  = TorontoFaceDatabase(which_sets=['test'], size=size, sources=['features'])
    elif data_name == 'bars':
        from bars_data import Bars

        width = 4
        x_dim = width*width

        data_train = Bars(num_examples=5000, width=width, sources=['features'])
        data_valid = Bars(num_examples=5000, width=width, sources=['features'])
        data_test  = Bars(num_examples=5000, width=width, sources=['features'])
    elif data_name in local_datasets:
        from fuel.datasets.hdf5 import H5PYDataset

        fname = "data/"+data_name+".hdf5"
        
        data_train = H5PYDataset(fname, which_sets=["train"], sources=['features'], load_in_memory=True)
        data_valid = H5PYDataset(fname, which_sets=["valid"], sources=['features'], load_in_memory=True)
        data_test  = H5PYDataset(fname, which_sets=["test"], sources=['features'], load_in_memory=True)

        # Infer x_dim from the first 100 examples, flattened.
        some_features = data_train.get_data(None, slice(0, 100))[0]
        assert some_features.shape[0] == 100

        some_features = some_features.reshape([100, -1])
        x_dim = some_features.shape[1]
    else:
        raise ValueError("Unknown dataset %s" % data_name)

    return x_dim, data_train, data_valid, data_test
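
For the local-dataset branch in Examples #4 and #5 to work, the HDF5 file must carry the split metadata H5PYDataset expects. A minimal sketch that builds such a file with h5py and Fuel's H5PYDataset.create_split_array (file name and split sizes are illustrative):

import h5py
import numpy as np
from fuel.datasets.hdf5 import H5PYDataset

with h5py.File("data/toy.hdf5", mode="w") as f:
    features = f.create_dataset("features", (120, 784), dtype="float32")
    features[...] = np.random.rand(120, 784).astype("float32")
    # Mark which rows belong to each split.
    split_dict = {
        'train': {'features': (0, 100)},
        'valid': {'features': (100, 110)},
        'test': {'features': (110, 120)},
    }
    f.attrs['split'] = H5PYDataset.create_split_array(split_dict)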
Example #6
def main(name, epochs, batch_size, learning_rate):
    if name is None:
        name = "att-rw"

    print("\nRunning experiment %s" % name)
    print("         learning rate: %5.3f" % learning_rate)
    print()

    #------------------------------------------------------------------------

    img_height, img_width = 28, 28

    read_N = 12
    write_N = 14

    inits = {
        #'weights_init': Orthogonal(),
        'weights_init': IsotropicGaussian(0.001),
        'biases_init': Constant(0.),
    }

    x_dim = img_height * img_width

    reader = ZoomableAttentionWindow(img_height, img_width, read_N)
    writer = ZoomableAttentionWindow(img_height, img_width, write_N)

    # Parameterize the attention reader and writer
    mlpr = MLP(activations=[Tanh(), Identity()],
               dims=[x_dim, 50, 5],
               name="RMLP",
               **inits)
    mlpw = MLP(activations=[Tanh(), Identity()],
               dims=[x_dim, 50, 5],
               name="WMLP",
               **inits)

    # MLP between the reader and writer
    mlp = MLP(activations=[Tanh(), Identity()],
              dims=[read_N**2, 300, write_N**2],
              name="MLP",
              **inits)

    for brick in [mlpr, mlpw, mlp]:
        brick.allocate()
        brick.initialize()

    #------------------------------------------------------------------------
    x = tensor.matrix('features')

    hr = mlpr.apply(x)
    hw = mlpw.apply(x)

    center_y, center_x, delta, sigma, gamma = reader.nn2att(hr)
    r = reader.read(x, center_y, center_x, delta, sigma)

    h = mlp.apply(r)

    center_y, center_x, delta, sigma, gamma = writer.nn2att(hw)
    c = writer.write(h, center_y, center_x, delta, sigma) / gamma
    x_recons = T.nnet.sigmoid(c)

    cost = BinaryCrossEntropy().apply(x, x_recons)
    cost.name = "cost"

    #------------------------------------------------------------
    cg = ComputationGraph([cost])
    params = VariableFilter(roles=[PARAMETER])(cg.variables)

    algorithm = GradientDescent(
        cost=cost,
        params=params,
        step_rule=CompositeRule([
            RemoveNotFinite(),
            Adam(learning_rate),
            StepClipping(3.),
        ])
        #step_rule=RMSProp(learning_rate),
        #step_rule=Momentum(learning_rate=learning_rate, momentum=0.95)
    )

    #------------------------------------------------------------------------
    # Setup monitors
    monitors = [cost]
    #for v in [center_y, center_x, log_delta, log_sigma, log_gamma]:
    #    v_mean = v.mean()
    #    v_mean.name = v.name
    #    monitors += [v_mean]
    #    monitors += [aggregation.mean(v)]

    train_monitors = monitors[:]
    train_monitors += [aggregation.mean(algorithm.total_gradient_norm)]
    train_monitors += [aggregation.mean(algorithm.total_step_norm)]

    # Live plotting...
    plot_channels = [
        ["cost"],
    ]

    #------------------------------------------------------------

    mnist_train = BinarizedMNIST("train", sources=['features'])
    mnist_test = BinarizedMNIST("test", sources=['features'])
    #mnist_train = MNIST("train", binary=True, sources=['features'])
    #mnist_test = MNIST("test", binary=True, sources=['features'])

    main_loop = MainLoop(
        model=Model(cost),
        data_stream=ForceFloatX(
            DataStream(mnist_train,
                       iteration_scheme=SequentialScheme(
                           mnist_train.num_examples, batch_size))),
        algorithm=algorithm,
        extensions=[
            Timing(),
            FinishAfter(after_n_epochs=epochs),
            DataStreamMonitoring(
                monitors,
                ForceFloatX(
                    DataStream(mnist_test,
                               iteration_scheme=SequentialScheme(
                                   mnist_test.num_examples, batch_size))),
                prefix="test"),
            TrainingDataMonitoring(train_monitors,
                                   prefix="train",
                                   after_every_epoch=True),
            SerializeMainLoop(name + ".pkl"),
            #Plot(name, channels=plot_channels),
            ProgressBar(),
            Printing()
        ])
    main_loop.run()
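
ZoomableAttentionWindow is defined elsewhere in the DRAW repository; its read and write operations are the Gaussian-filterbank attention of the DRAW paper. A rough numpy sketch of the underlying math, assuming the paper's parameterization (the real class builds the same matrices symbolically in Theano):

import numpy as np

def filterbank(center, delta, sigma, N, size):
    # N window points centered at `center`, spaced `delta` pixels apart.
    mu = center + (np.arange(N) - N / 2.0 + 0.5) * delta        # (N,)
    a = np.arange(size)                                         # (size,)
    F = np.exp(-(a[None, :] - mu[:, None]) ** 2 / (2.0 * sigma ** 2))
    return F / (F.sum(axis=1, keepdims=True) + 1e-8)            # rows sum to 1

# Reading a 12x12 patch from a 28x28 image: r = F_y . img . F_x^T
img = np.random.rand(28, 28)
Fy = filterbank(center=14.0, delta=1.0, sigma=1.0, N=12, size=28)
Fx = filterbank(center=14.0, delta=1.0, sigma=1.0, N=12, size=28)
patch = Fy.dot(img).dot(Fx.T)   # shape (12, 12)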
Example #7
def get_data(data_name):
    if data_name == 'mnist':
        from fuel.datasets import MNIST
        img_size = (28, 28)
        channels = 1
        data_train = MNIST(which_sets=["train"], sources=['features'])
        data_valid = MNIST(which_sets=["test"], sources=['features'])
        data_test = MNIST(which_sets=["test"], sources=['features'])
    elif data_name == 'bmnist':
        from fuel.datasets.binarized_mnist import BinarizedMNIST
        img_size = (28, 28)
        channels = 1
        data_train = BinarizedMNIST(which_sets=['train'], sources=['features'])
        data_valid = BinarizedMNIST(which_sets=['valid'], sources=['features'])
        data_test = BinarizedMNIST(which_sets=['test'], sources=['features'])
    # TODO: make a generic catch-all for loading custom datasets like "colormnist"
    elif data_name == 'colormnist':
        from draw.colormnist import ColorMNIST
        img_size = (28, 28)
        channels = 3
        data_train = ColorMNIST(which_sets=['train'], sources=['features'])
        data_valid = ColorMNIST(which_sets=['test'], sources=['features'])
        data_test = ColorMNIST(which_sets=['test'], sources=['features'])
    elif data_name == 'cifar10':
        from fuel.datasets.cifar10 import CIFAR10
        img_size = (32, 32)
        channels = 3
        data_train = CIFAR10(which_sets=['train'], sources=['features'])
        data_valid = CIFAR10(which_sets=['test'], sources=['features'])
        data_test = CIFAR10(which_sets=['test'], sources=['features'])
    elif data_name == 'svhn2':
        from fuel.datasets.svhn import SVHN
        img_size = (32, 32)
        channels = 3
        data_train = SVHN(which_format=2,
                          which_sets=['train'],
                          sources=['features'])
        data_valid = SVHN(which_format=2,
                          which_sets=['test'],
                          sources=['features'])
        data_test = SVHN(which_format=2,
                         which_sets=['test'],
                         sources=['features'])
    elif data_name == 'silhouettes':
        from fuel.datasets.caltech101_silhouettes import CalTech101Silhouettes
        size = 28
        img_size = (size, size)
        channels = 1
        data_train = CalTech101Silhouettes(which_sets=['train'],
                                           size=size,
                                           sources=['features'])
        data_valid = CalTech101Silhouettes(which_sets=['valid'],
                                           size=size,
                                           sources=['features'])
        data_test = CalTech101Silhouettes(which_sets=['test'],
                                          size=size,
                                          sources=['features'])
    elif data_name == 'tfd':
        from fuel.datasets.toronto_face_database import TorontoFaceDatabase
        size = 28
        img_size = (size, size)
        channels = 1
        data_train = TorontoFaceDatabase(which_sets=['unlabeled'],
                                         size=size,
                                         sources=['features'])
        data_valid = TorontoFaceDatabase(which_sets=['valid'],
                                         size=size,
                                         sources=['features'])
        data_test = TorontoFaceDatabase(which_sets=['test'],
                                        size=size,
                                        sources=['features'])
    else:
        raise ValueError("Unknown dataset %s" % data_name)

    return img_size, channels, data_train, data_valid, data_test
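
A short usage sketch for this variant, which also returns a channel count; a model consuming flattened inputs would derive its input dimension as channels times the pixel count:

img_size, channels, data_train, data_valid, data_test = get_data('cifar10')
x_dim = channels * img_size[0] * img_size[1]   # 3 * 32 * 32 = 3072
print(img_size, channels, x_dim)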