Ejemplo n.º 1
0
def fit2():
    """Train a 1-input dense network (128 -> 256 -> dropout -> 1) and plot it.

    Uses MSE loss with a ``Fixed`` optimizer, fits on the module-level ``ds``
    (200000 passed as the run length) with ``val_ds`` as validation data and
    ``val_epochs=1000``, forwards every 'val_end' event to
    ``on_val_end(sess, history)``, and finally writes the report image to
    ``report_path + '02.png'``.
    """
    print("fit2")
    layers = [
        nn.Dense(128, inshape=1, activation='relu'),
        nn.Dense(256, activation='relu'),
        nn.Dropout(0.80),
        nn.Dense(1),
    ]
    net = Model(layers)
    net.assemble()

    sess = Session(net, loss=losses.Mse(), optimizer=optimizers.Fixed())

    def _after_validation(h):
        # Hand the validation-end history to the module-level hook together
        # with the session being trained.
        return on_val_end(sess, h)

    val_listener = FitListener('val_end', callback=_after_validation)
    history = sess.fit(ds, 200000,
                       val_data=val_ds,
                       val_epochs=1000,
                       listeners=[val_listener])

    fit_report(history, report_path + '02.png', 15)
Ejemplo n.º 2
0
    @classmethod
    def setUpClass(cls):
        """Build and assemble the shared three-layer ``Simplelayer`` model.

        unittest invokes ``setUpClass`` on the class itself (no instance), so
        the hook must be a classmethod; without the decorator the call fails
        with a missing-argument ``TypeError``. The assembled model is stored
        as ``cls.model`` for every test in the class.
        """
        cls.model = Model(
            [Simplelayer(3, inshape=2),
             Simplelayer(4),
             Simplelayer(1)])

        cls.model.assemble()
Ejemplo n.º 3
0
    @classmethod
    def setUpClass(cls):
        """Create the shared training session used by the tests in this class.

        Assembles a three-layer ``Simplelayer`` model (declared with
        ``inshape=3``) and wraps it in a ``Session`` with MSE loss and a
        ``Fixed`` optimizer, stored as ``cls.sess``. unittest calls this hook
        on the class, so it must be a classmethod.
        """
        model = Model(
            [Simplelayer(4, inshape=3),
             Simplelayer(5),
             Simplelayer(1)])
        model.assemble()

        # Only the session is kept on the class; the model is not stored
        # separately here.
        cls.sess = Session(model,
                           loss=losses.Mse(),
                           optimizer=optimizers.Fixed())
Ejemplo n.º 4
0
    @classmethod
    def setUpClass(cls):
        """Build and assemble the shared small CNN used by the tests.

        Layers: Conv2D(3, 3x3, inshape=(2, 12, 12)) -> MaxPool2D ->
        Conv2D(4, 5x5, padding='valid') -> Flatten -> Dense(10). The assembled
        model is stored as ``cls.model``. unittest calls this hook on the
        class, so it must be a classmethod.
        """
        model = Model([
                cnn.Conv2D(3, (3,3), inshape=(2,12,12)),
                cnn.MaxPool2D(),
                cnn.Conv2D(4, (5, 5), padding='valid'),
                nn.Flatten(),
                nn.Dense(10)
                ])
        model.assemble()

        cls.model = model
Ejemplo n.º 5
0
def fit_2():
    """Fit a ReLU dense network (32 -> 1) on ``ds_0`` and plot its history.

    Trains via ``train(30000, ds_0, ...)`` recording every 300 epochs, drops
    the first 5 recorded points of 'loss', 'val_loss' and 'epochs', then
    writes the report to ``report_path + '02.png'``.
    """
    layers = [nnlys.Dense(32, inshape=1, activation='relu'), nnlys.Dense(1)]
    net = Model(layers)
    net.assemble()
    # The ReLU activation makes the model nonlinear, increasing its capacity.
    history = train(30000, ds_0, net, record_epochs=300)
    # Discard the first 5 records of each plotted series (presumably so the
    # large early losses do not dominate the plot scale).
    for series in ('loss', 'val_loss', 'epochs'):
        history[series] = history[series][5:]
    fit_report(history, report_path + '02.png')
Ejemplo n.º 6
0
def fit_lstm():
    """Build a bidirectional-LSTM model and train it via ``fit('lstm', ...)``.

    Embedding(64, vocab_size + 1) -> Bidirectional(LSTM(128), LSTM(128)) ->
    Filter -> Dense(64) -> Dropout(0.5) -> Dense(1, linear).
    """
    print("fit lstm")
    layers = [
        rnn.Embedding(64, vocab_size + 1),
        wrapper.Bidirectional(rnn.LSTM(128), rnn.LSTM(128)),
        nn.Filter(),
        nn.Dense(64),
        nn.Dropout(0.5),
        nn.Dense(1, activation='linear'),
    ]
    lstm_model = Model(layers)
    lstm_model.assemble()
    fit('lstm', lstm_model)
Ejemplo n.º 7
0
def fit_gru():
    """Build a bidirectional-GRU model and train it via ``fit('gru', ...)``.

    Embedding(64, vocab_size + 1) -> Bidirectional(GRU(64), GRU(64)) ->
    Filter -> Dense(64) -> Dropout(0.5) -> Dense(1, linear).
    """
    print("fit gru")
    layers = [
        rnn.Embedding(64, vocab_size + 1),
        wrapper.Bidirectional(rnn.GRU(64), rnn.GRU(64)),
        nn.Filter(),
        nn.Dense(64),
        nn.Dropout(0.5),
        nn.Dense(1, activation='linear'),
    ]
    gru_model = Model(layers)
    gru_model.assemble()
    fit('gru', gru_model)
Ejemplo n.º 8
0
def fit_gru():
    """Build a stateful GRU model sized to the vocabulary and train it.

    The output layer has one linear unit per vocabulary entry
    (``vocab.size()``). Training is delegated to ``fit("gru", ...)``.
    """
    n_tokens = vocab.size()
    print("vocab size: ", n_tokens)
    layers = [
        rnn.Embedding(256, n_tokens),
        rnn.GRU(1024, stateful=True),
        nn.Dense(1024),
        nn.Dropout(0.5),
        nn.Dense(n_tokens, activation='linear'),
    ]
    gru_model = Model(layers)

    gru_model.assemble()
    fit("gru", gru_model)
Ejemplo n.º 9
0
def fit_4():
    """Fit a ReLU dense network (128 -> 64 -> 1) on ``ds_1`` and plot it.

    Trains via ``train(30000, ds_1, ..., batch_size=16)`` recording every 300
    epochs, drops the first 20 recorded points of 'loss', 'val_loss' and
    'epochs', then writes the report to ``report_path + '04.png'``.
    """
    layers = [
        nnlys.Dense(128, inshape=1, activation='relu'),
        nnlys.Dense(64, activation='relu'),
        nnlys.Dense(1),
    ]
    net = Model(layers)
    net.assemble()

    history = train(30000, ds_1, net, batch_size=16, record_epochs=300)
    # Discard the first 20 records of each plotted series before reporting.
    for series in ('loss', 'val_loss', 'epochs'):
        history[series] = history[series][20:]
    fit_report(history, report_path + '04.png')
Ejemplo n.º 10
0
def fit(name, optimizer):
    """Train a small CNN on ``ds_train`` with the given optimizer.

    The network is Conv(32) -> Pool -> Conv(64) -> Pool -> Conv(64) ->
    Flatten -> Dense(64) -> Dropout(0.5) -> Dense(10), trained with
    categorical cross-entropy (200 passed as the run length, ``val_steps=100``
    on ``ds_test``). Three listeners are registered: a
    ``condition_callback`` on 'val_loss' (argument 30) that calls
    ``sess.stop_fit()``, an accuracy report, and a report/checkpoint step.
    The report/checkpoint step also runs once more after fitting returns.

    Args:
        name: Tag used for the report image (``report_path + name + ".png"``)
            and the saved model (``model_path + name``).
        optimizer: Optimizer instance handed to ``Session``.
    """
    input_shape = ds_train.data.shape[1:]
    layers = [
        cnn.Conv2D(32, (3, 3), inshape=input_shape),
        cnn.MaxPool2D((2, 2), strides=(2, 2)),
        cnn.Conv2D(64, (3, 3)),
        cnn.MaxPool2D((2, 2), strides=(2, 2)),
        cnn.Conv2D(64, (3, 3)),
        nn.Flatten(),
        nn.Dense(64),
        nn.Dropout(0.5),
        nn.Dense(10),
    ]
    net = Model(layers)
    net.assemble()

    sess = Session(net,
                   loss=losses.CategoricalCrossentropy(),
                   optimizer=optimizer)

    def _halt():
        return sess.stop_fit()

    stop_fit = session.condition_callback(_halt, 'val_loss', 30)

    def _log_accuracy(h):
        return fit_tools.accuracy(sess, ds_test, h)

    def save_and_report(history):
        # Refresh the plot and checkpoint the model under the same tag.
        fit_tools.fit_report(history, report_path + name + ".png")
        net.save(model_path + name)

    listeners = [
        stop_fit,
        session.FitListener('val_end', callback=_log_accuracy),
        session.FitListener('val_end', callback=save_and_report),
    ]
    history = sess.fit(ds_train, 200,
                       val_data=ds_test,
                       val_steps=100,
                       listeners=listeners)

    save_and_report(history)
Ejemplo n.º 11
0
def fit():
    """Train a single Dense(10, ReLU) layer with categorical cross-entropy.

    Uses ``optimizers.Fixed(0.001)``, fits ``ds_train`` (20000 passed as the
    run length) with ``ds_test`` as validation data every ``val_epochs=5``,
    stops via ``sess.stop_fit()`` from a ``condition_callback`` on
    'val_loss' (argument 10), and writes the report to
    ``report_path + "0.png"``.
    """
    n_features = ds_train.data.shape[1]
    net = Model([nn.Dense(10, inshape=n_features, activation='relu')])
    net.assemble()

    sess = Session(net,
                   loss=losses.CategoricalCrossentropy(),
                   optimizer=optimizers.Fixed(0.001))

    def _halt():
        return sess.stop_fit()

    stop_fit = session.condition_callback(_halt, 'val_loss', 10)

    # NOTE(review): `accuracy` is not defined in this function; it is
    # presumably a module-level callback elsewhere in the file — confirm it
    # exists, otherwise this raises NameError.
    listeners = [
        stop_fit,
        session.FitListener('val_end', callback=accuracy),
    ]
    history = sess.fit(ds_train, 20000,
                       val_epochs=5,
                       val_data=ds_test,
                       listeners=listeners)

    fit_report(history, report_path + "0.png")
Ejemplo n.º 12
0
def fit_1():
    """Fit a purely linear dense network (32 -> 1) on ``ds_0`` and plot it.

    Trains via ``train(20000, ds_0, ...)`` recording every 100 epochs and
    writes the report to ``report_path + '01.png'``.
    """
    layers = [nnlys.Dense(32, inshape=1), nnlys.Dense(1)]
    linear_net = Model(layers)
    linear_net.assemble()
    # No activation functions, so this is a linear model. Used to fit a
    # nonlinear function, it lacks capacity and is expected to underfit.
    history = train(20000, ds_0, linear_net, record_epochs=100)
    fit_report(history, report_path + '01.png')