Code example #1
File: train.py  Project: satopirka/nlp-nnabla
import nnabla as nn
import nnabla.solvers as S

# Trainer and build_self_attention_model are project-local helpers.
# Build the training graph, then register its parameters with the solver.
x, t, accuracy, loss = build_self_attention_model(train=True)
solver = S.Adam()
solver.set_parameters(nn.get_parameters())

trainer = Trainer(inputs=[x, t],
                  loss=loss,
                  metrics={
                      'cross entropy': loss,
                      'accuracy': accuracy
                  },
                  solver=solver)
for epoch in range(max_epoch):
    # Rebuild the training graph (train=True) and rebind the trainer to it.
    x, t, accuracy, loss = build_self_attention_model(train=True)
    trainer.update_variables(inputs=[x, t],
                             loss=loss,
                             metrics={
                                 'cross entropy': loss,
                                 'accuracy': accuracy
                             })
    trainer.run(train_data_iter, None, epochs=1, verbose=1)

    # Rebuild the graph in evaluation mode and evaluate on the dev set.
    x, t, accuracy, loss = build_self_attention_model(train=False)
    trainer.update_variables(inputs=[x, t],
                             loss=loss,
                             metrics={
                                 'cross entropy': loss,
                                 'accuracy': accuracy
                             })
    trainer.evaluate(dev_data_iter, verbose=1)
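
Trainer here is a project-local helper whose internals are not shown in this snippet. For orientation, one epoch of trainer.run presumably reduces to a standard nnabla update loop like the minimal sketch below; the iterations-per-epoch computation and the iterator protocol are assumptions, while x, t, loss, and solver are the variables built above.

# Minimal sketch of one epoch without the Trainer helper (assumptions
# noted above; not the project's actual implementation).
iterations = train_data_iter.size // train_data_iter.batch_size
for _ in range(iterations):
    x.d, t.d = train_data_iter.next()       # bind the next mini-batch
    loss.forward(clear_no_need_grad=True)   # forward pass
    solver.zero_grad()                      # reset accumulated gradients
    loss.backward(clear_buffer=True)        # backpropagate
    solver.update()                         # apply the Adam update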
Code example #2
File: train.py  Project: satopirka/nlp-nnabla
import numpy as np
import nnabla as nn
import nnabla.solvers as S

# Build the training graph first so that nn.get_parameters() returns
# the model parameters to register with the solver.
x, t, loss = build_model(train=True)
solver = S.Momentum(1e-2, momentum=0.9)
solver.set_parameters(nn.get_parameters())

# 'PPL' is perplexity, tracked as e raised to the cross-entropy loss.
trainer = Trainer(inputs=[x, t], loss=loss, metrics={'PPL': np.e**loss}, solver=solver, save_path='char-cnn-lstmlm')

for epoch in range(max_epoch):
    # Rebuild the training graph and rebind the trainer to it.
    x, t, loss = build_model(train=True)
    trainer.update_variables(inputs=[x, t], loss=loss, metrics={'PPL': np.e**loss})
    trainer.run(train_data_iter, None, epochs=1, verbose=1)

    # Rebuild in evaluation mode and measure perplexity on the validation set.
    x, t, loss = build_model(train=False)
    trainer.update_variables(inputs=[x, t], loss=loss, metrics={'PPL': np.e**loss})
    trainer.evaluate(valid_data_iter, verbose=1)
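
Since loss is an nn.Variable, the expression np.e**loss does not evaluate immediately: it adds a scalar-power node to the graph, so the 'PPL' metric is recomputed from the current cross-entropy on every forward pass. What trainer.evaluate presumably does per epoch can be sketched as follows; the iterator protocol is assumed, as in the sketch above, and x, t, loss here come from build_model(train=False).

# Minimal validation sketch without the Trainer helper.
n = valid_data_iter.size // valid_data_iter.batch_size
loss_total = 0.0
for _ in range(n):
    x.d, t.d = valid_data_iter.next()
    loss.forward(clear_no_need_grad=True)   # no backward pass in evaluation
    loss_total += float(loss.d)             # accumulate cross-entropy
print('valid PPL: {:.2f}'.format(np.e ** (loss_total / n)))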

# Optional post-training step (commented out): reload the best checkpoint
# and extract per-word embeddings from the character-level CNN.

# nn.load_parameters('char-cnn-lstm_best.h5')

# batch_size = 1
# sentence_length = 1
# x, embeddings = build_model(get_embeddings=True)

# W = np.zeros((len(w2i), sum(filters)))
# for i, word in enumerate(w2i):
#     vec = wordseq2charseq([[w2i[word]]])
#     x.d = vec
#     embeddings.forward(clear_no_need_grad=True)
#     W[w2i[word], :] = embeddings.d[0][0]

# def get_word_from_id(id):
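
The commented block above ends mid-definition (get_word_from_id is truncated in the source). Once the embedding matrix W has been filled, a typical use is nearest-neighbour queries; the helper below is a hypothetical illustration of that, not code from the project.

# Hypothetical helper: cosine-similarity nearest neighbours in the
# embedding matrix W built above (w2i maps words to row indices).
def nearest_words(word, W, w2i, k=5):
    i2w = {i: w for w, i in w2i.items()}    # inverse vocabulary mapping
    v = W[w2i[word]]
    sims = (W @ v) / (np.linalg.norm(W, axis=1) * np.linalg.norm(v) + 1e-8)
    order = np.argsort(-sims)
    return [i2w[j] for j in order if j != w2i[word]][:k]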