Example #1
import unittest

from lstm_dataset import S2SDataSet
from vocab_dict import get_dict


# The enclosing test case class is not shown in the original snippet; its name here is illustrative.
class S2SDataSetTest(unittest.TestCase):
    def testLoad(self):
        vdict, idict = get_dict()

        ds = S2SDataSet(vdict, idict, 'bobsue-data/bobsue.seq2seq.dev.tsv')

        for batch in ds.batches(30):
            self.assertEqual(2, len(batch.data))
            self.assertEqual(batch.data[0].shape[0], batch.data[1].shape[0])
Example #2
from common_train import Trainer
from lm_loss import LogLoss
from lstm_dataset import S2SDataSet
from lstm_graph import BiLSTMEncodeGraph
from ndnn.sgd import Adam
from vocab_dict import get_dict

vocab_dict, idx_dict = get_dict()

train_ds = S2SDataSet(vocab_dict, idx_dict, "bobsue-data/bobsue.seq2seq.train.tsv")
dev_ds = S2SDataSet(vocab_dict, idx_dict, "bobsue-data/bobsue.seq2seq.dev.tsv")
test_ds = S2SDataSet(vocab_dict, idx_dict, "bobsue-data/bobsue.seq2seq.test.tsv")

dict_size = len(vocab_dict)
hidden_dim = 200
batch_size = 50

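# Train a bidirectional-LSTM encoder sequence-to-sequence model with log loss and Adam.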
trainer = Trainer()
graph = BiLSTMEncodeGraph(LogLoss(), Adam(eta=0.001, decay=0.99), dict_size, hidden_dim)
trainer.train(idx_dict, 100, 's2s_bilstm', graph, train_ds, dev_ds, test_ds, 50)
Example #3
import numpy as np

from lm_loss import LogLoss
from lstm_dataset import S2SDataSet
from lstm_graph import BiLSTMDecodeGraph, AttentionDecodeGraph
from ndnn.dataset import Batch
from ndnn.store import ParamStore
from vocab_dict import get_dict, translate

vocab_dict, idx_dict = get_dict()

dev_ds = S2SDataSet(vocab_dict, idx_dict, "bobsue-data/bobsue.seq2seq.dev.tsv")

dict_size = len(vocab_dict)
hidden_dim = 200
batch_size = 50

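# Only one decoder/checkpoint pair should be active at a time; the commented-out
# lines switch to the plain BiLSTM decoder and its saved model instead.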
#graph = BiLSTMDecodeGraph(LogLoss(), dict_size, hidden_dim, 50)
graph = AttentionDecodeGraph(LogLoss(), dict_size, hidden_dim, 50)
#store = ParamStore("model/s2s_bilstm.mdl")
store = ParamStore("model/s2s_attention.mdl")
graph.load(store.load())

num_sample = 10

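# Pick a random group from the dev set, then a random source sequence within it.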
for i in range(num_sample):
    gi = np.random.randint(0, len(dev_ds.datas))
    group = dev_ds.datas[gi]

    ii = np.random.randint(0, len(group))
    data = np.int32(group[ii][0]).reshape([1, -1])
Example #4
import numpy as np

from lm_loss import LogLoss
from lstm_dataset import LSTMDataSet, S2SDataSet
from lstm_graph import LSTMEncodeGraph, BiLSTMEncodeGraph, BowEncodeGraph
from ndnn.dataset import Batch
from ndnn.sgd import Adam
from ndnn.store import ParamStore
from vocab_dict import get_dict, translate

vocab_dict, idx_dict = get_dict()

lmdev_ds = LSTMDataSet(vocab_dict, idx_dict, "bobsue-data/bobsue.lm.dev.txt")
s2strain_ds = S2SDataSet(vocab_dict, idx_dict,
                         "bobsue-data/bobsue.seq2seq.train.tsv")

dict_size = len(vocab_dict)
hidden_dim = 200
batch_size = 50

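# Restore three pre-trained encoder variants (plain LSTM, bidirectional LSTM,
# and bag-of-words) from their saved parameter stores.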
lstm_encode_graph = LSTMEncodeGraph(LogLoss(), Adam(eta=0.001), dict_size,
                                    hidden_dim)
lstm_encode_store = ParamStore("model/s2s_lstm.mdl")
lstm_encode_graph.load(lstm_encode_store.load())

bilstm_encode_graph = BiLSTMEncodeGraph(LogLoss(), Adam(eta=0.001), dict_size,
                                        hidden_dim)
bilstm_encode_store = ParamStore("model/s2s_bilstm.mdl")
bilstm_encode_graph.load(bilstm_encode_store.load())

bow_encode_graph = BowEncodeGraph(LogLoss(), Adam(eta=0.001), dict_size,