コード例 #1
0
def fit2():
    """Train a small dense network on ``ds`` and write a fit report image.

    Network: Dense(128, inshape=1, relu) -> Dense(256, relu) -> Dropout(0.80)
    -> Dense(1), trained with ``losses.Mse`` and ``optimizers.Fixed``.

    ``sess.fit`` receives ``200000`` as its second positional argument and
    validates against ``val_ds`` with ``val_epochs=1000``. NOTE(review): whether
    these count steps or epochs is defined by ``Session.fit`` — confirm there.
    Every 'val_end' event is forwarded to ``on_val_end`` together with the
    session. The resulting history is rendered to ``report_path + '02.png'``.
    """
    print("fit2")
    layers = [
        nn.Dense(128, inshape=1, activation='relu'),
        nn.Dense(256, activation='relu'),
        nn.Dropout(0.80),
        nn.Dense(1),
    ]
    model = Model(layers)
    model.assemble()

    sess = Session(model, loss=losses.Mse(), optimizer=optimizers.Fixed())

    def _forward_val_end(hist):
        # Pass the session along so the handler can inspect/drive training.
        return on_val_end(sess, hist)

    history = sess.fit(
        ds,
        200000,
        val_data=val_ds,
        val_epochs=1000,
        listeners=[FitListener('val_end', callback=_forward_val_end)],
    )

    # NOTE(review): meaning of the trailing 15 is defined by fit_report — confirm.
    fit_report(history, report_path + '02.png', 15)
コード例 #2
0
    def setUpClass(cls):
        """Create the Dense -> Dropout -> Dense fixture layers for this class.

        Each later layer is attached to its predecessor through ``join``.

        NOTE(review): unittest invokes this hook as a classmethod; the
        ``@classmethod`` decorator is not visible in this excerpt — confirm
        it is present.
        """
        dense_in = nn.Dense(4, inshape=2, activation='relu')
        cls.ly_dense = dense_in
        dropout = nn.Dropout(keep_prob=0.5)
        cls.ly_dropout = dropout
        dense_out = nn.Dense(6)
        cls.ly_dense1 = dense_out

        # Link the chain front-to-back: dropout follows dense_in, dense_out follows dropout.
        for layer, predecessor in ((dropout, dense_in), (dense_out, dropout)):
            layer.join(predecessor)
コード例 #3
0
def fit_lstm():
    """Build a bidirectional-LSTM model and train it via the shared ``fit`` helper.

    Layers: Embedding(64, vocab_size+1) -> Bidirectional(LSTM(128), LSTM(128))
    -> Filter -> Dense(64) -> Dropout(0.5) -> Dense(1, activation='linear').
    ``vocab_size`` is read from module scope. The assembled model is handed
    to ``fit`` under the run name ``'lstm'``.
    """
    print("fit lstm")
    embedding = rnn.Embedding(64, vocab_size+1)
    encoder = wrapper.Bidirectional(rnn.LSTM(128), rnn.LSTM(128))
    head = [
        nn.Filter(),
        nn.Dense(64),
        nn.Dropout(0.5),
        nn.Dense(1, activation='linear'),
    ]
    model = Model([embedding, encoder] + head)
    model.assemble()
    fit('lstm', model)
コード例 #4
0
def fit_gru():
    """Build a bidirectional-GRU model and train it via the shared ``fit`` helper.

    Layers: Embedding(64, vocab_size+1) -> Bidirectional(GRU(64), GRU(64))
    -> Filter -> Dense(64) -> Dropout(0.5) -> Dense(1, activation='linear').
    ``vocab_size`` is read from module scope. After ``assemble()`` the model
    is passed to ``fit`` under the run name ``'gru'``.
    """
    print("fit gru")
    embedding = rnn.Embedding(64, vocab_size+1)
    bigru = wrapper.Bidirectional(rnn.GRU(64), rnn.GRU(64))
    model = Model([
        embedding,
        bigru,
        nn.Filter(),
        nn.Dense(64),
        nn.Dropout(0.5),
        nn.Dense(1, activation='linear'),
    ])
    model.assemble()
    fit('gru', model)
コード例 #5
0
def fit_gru():
    """Build a stateful GRU model sized to the vocabulary and train it.

    Reads the size from ``vocab.size()`` and prints it. Then builds
    Embedding(256, size) -> GRU(1024, stateful=True) -> Dense(1024)
    -> Dropout(0.5) -> Dense(size, activation='linear'), assembles the model,
    and passes it to ``fit`` under the run name ``"gru"``.
    """
    n_tokens = vocab.size()
    print("vocab size: ", n_tokens)

    recurrent_stack = [
        rnn.Embedding(256, n_tokens),
        rnn.GRU(1024, stateful=True),
    ]
    # The final Dense layer is sized by the vocabulary size as well.
    output_stack = [
        nn.Dense(1024),
        nn.Dropout(0.5),
        nn.Dense(n_tokens, activation='linear'),
    ]
    model = Model(recurrent_stack + output_stack)

    model.assemble()
    fit("gru", model)
コード例 #6
0
ファイル: ut_nn_layers.py プロジェクト: brandonlyg/cute-dl
    def setUpClass(cls):
        """Build and initialise a Dense -> Dropout -> Dense chain under one root.

        The layers are linked with ``set_prev`` and attached to a fresh
        ``RootLayer`` via ``set_parent``. Only after all parents are set does
        each layer have ``init_params`` called on it, in chain order.

        NOTE(review): unittest invokes this hook as a classmethod; the
        ``@classmethod`` decorator is not visible in this excerpt — confirm
        it is present.
        """
        first = nn.Dense(4, inshape=2, activation='relu')
        cls.ly_dense = first
        drop = nn.Dropout(keep_prob=0.5)
        cls.ly_dropout = drop
        last = nn.Dense(6)
        cls.ly_dense1 = last

        parent = RootLayer()

        # Predecessor links, front to back.
        drop.set_prev(first)
        last.set_prev(drop)

        chain = (first, drop, last)
        # Every layer gets its parent before any layer initialises parameters.
        for layer in chain:
            layer.set_parent(parent)
        for layer in chain:
            layer.init_params()
コード例 #7
0
def fit(name, optimizer):
    """Train a small CNN classifier on ``ds_train`` with the given optimizer.

    The input shape is ``ds_train.data.shape`` without its first axis.

    Network: Conv2D(32,3x3) -> MaxPool(2x2,s2) -> Conv2D(64,3x3) -> MaxPool(2x2,s2)
    -> Conv2D(64,3x3) -> Flatten -> Dense(64) -> Dropout(0.5) -> Dense(10).
    It is trained with ``losses.CategoricalCrossentropy``, and ``ds_test`` is
    used as validation data.

    Listeners:
    - a ``session.condition_callback`` on 'val_loss' (threshold 30) that calls
      ``sess.stop_fit()``;
    - on every 'val_end', ``fit_tools.accuracy`` is evaluated on ``ds_test``;
    - on every 'val_end', the report image
      (``report_path + name + ".png"``) and the model
      (``model_path + name``) are saved.
    The report/save step runs once more after training finishes.

    Args:
        name: run name used to build the report and model paths.
        optimizer: optimizer instance passed straight to ``Session``.
    """
    sample_shape = ds_train.data.shape[1:]
    feature_layers = [
        cnn.Conv2D(32, (3, 3), inshape=sample_shape),
        cnn.MaxPool2D((2, 2), strides=(2, 2)),
        cnn.Conv2D(64, (3, 3)),
        cnn.MaxPool2D((2, 2), strides=(2, 2)),
        cnn.Conv2D(64, (3, 3)),
    ]
    classifier_layers = [
        nn.Flatten(),
        nn.Dense(64),
        nn.Dropout(0.5),
        nn.Dense(10),
    ]
    model = Model(feature_layers + classifier_layers)
    model.assemble()

    sess = Session(model,
                   loss=losses.CategoricalCrossentropy(),
                   optimizer=optimizer)

    def request_stop():
        return sess.stop_fit()

    # NOTE(review): presumably triggers request_stop once 'val_loss' meets the
    # condition parameterised by 30 (e.g. no improvement) — confirm in session.
    stop_fit = session.condition_callback(request_stop, 'val_loss', 30)

    def accuracy(hist):
        return fit_tools.accuracy(sess, ds_test, hist)

    def save_and_report(hist):
        fit_tools.fit_report(hist, report_path + name + ".png")
        model.save(model_path + name)

    history = sess.fit(
        ds_train,
        200,
        val_data=ds_test,
        val_steps=100,
        listeners=[
            stop_fit,
            session.FitListener('val_end', callback=accuracy),
            session.FitListener('val_end', callback=save_and_report),
        ],
    )

    # Persist the final state as well, not just the last validation snapshot.
    save_and_report(history)