예제 #1
0
파일: rnn_lang.py 프로젝트: chagri/docvec
def main():
    """Train a SimpleDRNNLM, pickle it, and print ten generated sequences.

    Loads the data through ``get_data``, repeats the training set
    ``n_epochs`` times, trains with ``custom_train_sgd``, writes the model to
    ``../data/simple_drnnlm_model.pkl`` and finally prints ten sequences
    produced by ``generate_sequence``.
    """
    data_root = "../data/wordinds/"
    print("starting L matrix construction")
    gw = GloveWrapper(verbose=True)
    L0 = gw.L

    # os.path.join avoids the doubled slash of data_root + "/train/".
    train_files = os.listdir(os.path.join(data_root, "train"))

    print("getting train and test data")
    train_x, train_y, train_D, _ = get_data(gw, train_files, "train/")
    # NOTE(review): the test split is loaded using the *train* file listing
    # and its results are never used below -- confirm this is intended.
    test_x, test_y, test_D, _ = get_data(gw, train_files, "test/")
    # One random 300-column row per index up to train_D[-1]
    # (assumes the last entry of train_D is the largest index -- TODO confirm).
    D0 = np.random.randn(train_D[-1] + 1, 300)
    print("got train and test data")

    # Emulate epochs by repeating the training data.
    # NOTE(review): this relies on get_data returning Python lists; with numpy
    # arrays "*" would scale the values instead of repeating them -- confirm.
    n_epochs = 25
    train_x = train_x * n_epochs
    train_y = train_y * n_epochs
    train_D = train_D * n_epochs

    model = SimpleDRNNLM(L0, D0, U0=L0, alpha=0.05, rseed=10, bptt=3)
    print("constructed model, training...")
    model.custom_train_sgd(
        train_x,
        train_y,
        train_D,
        apply_to=["H", "U", "L", "D"],
        printevery=5,
        costevery=25,
    )
    print("training done")

    print("saving model")
    # Pickle output is binary data: "wb" is required on Python 3 and prevents
    # newline translation from corrupting the file on Windows.
    with open("../data/simple_drnnlm_model.pkl", "wb") as model_file:
        pkl.dump(model, model_file)
    print("model saved")

    for _ in range(10):
        # The second return value of generate_sequence is not needed here.
        seq, _ = model.generate_sequence(
            1, gw.get_index("SSTART"), gw.get_index("EEND"), maxlen=100
        )
        print(" ".join(seq_to_words(seq)))
예제 #2
0
def main():
    """Train a SimpleDRNNLM, pickle it, and print ten generated sequences.

    Loads the data through ``get_data``, repeats the training set
    ``n_epochs`` times, trains with ``custom_train_sgd``, writes the model to
    ``../data/simple_drnnlm_model.pkl`` and finally prints ten sequences
    produced by ``generate_sequence``.
    """
    data_root = '../data/wordinds/'
    print('starting L matrix construction')
    gw = GloveWrapper(verbose=True)
    L0 = gw.L

    # os.path.join avoids the doubled slash of data_root + '/train/'.
    train_files = os.listdir(os.path.join(data_root, 'train'))

    print('getting train and test data')
    train_x, train_y, train_D, _ = get_data(gw, train_files, 'train/')
    # NOTE(review): the test split is loaded using the *train* file listing
    # and its results are never used below -- confirm this is intended.
    test_x, test_y, test_D, _ = get_data(gw, train_files, 'test/')
    # One random 300-column row per index up to train_D[-1]
    # (assumes the last entry of train_D is the largest index -- TODO confirm).
    D0 = np.random.randn(train_D[-1] + 1, 300)
    print('got train and test data')

    # Emulate epochs by repeating the training data.
    # NOTE(review): this relies on get_data returning Python lists; with numpy
    # arrays '*' would scale the values instead of repeating them -- confirm.
    n_epochs = 25
    train_x = train_x * n_epochs
    train_y = train_y * n_epochs
    train_D = train_D * n_epochs

    model = SimpleDRNNLM(L0, D0, U0=L0, alpha=0.05, rseed=10, bptt=3)
    print('constructed model, training...')
    model.custom_train_sgd(train_x, train_y, train_D,
                           apply_to=['H', 'U', 'L', 'D'],
                           printevery=5, costevery=25)
    print('training done')

    print('saving model')
    # Pickle output is binary data: 'wb' is required on Python 3 and prevents
    # newline translation from corrupting the file on Windows.
    with open('../data/simple_drnnlm_model.pkl', 'wb') as model_file:
        pkl.dump(model, model_file)
    print('model saved')

    for _ in range(10):
        # The second return value of generate_sequence is not needed here.
        seq, _ = model.generate_sequence(1, gw.get_index("SSTART"),
                                         gw.get_index("EEND"), maxlen=100)
        print(" ".join(seq_to_words(seq)))
예제 #3
0
파일: grad_test.py 프로젝트: afgiel/docvec
import numpy as np
from simple_drnnlm import SimpleDRNNLM

# Smoke test for SimpleDRNNLM on a tiny random problem.
# NOTE(review): judging by the names, L0/U0 receive a 10x50 word-vector
# matrix and D0 a 1x50 document-vector matrix -- confirm in SimpleDRNNLM.
wv_dummy = np.random.randn(10, 50)
dv_dummy = 0.01 * np.random.randn(1, 50)  # scaled down to small values
model = SimpleDRNNLM(
    L0=wv_dummy,
    D0=dv_dummy,
    U0=wv_dummy,
    alpha=0.005,
    rseed=10,
    bptt=4
)
# The arrays are [1, 2, 3] and [2, 3, 4] (presumably inputs and next-step
# targets); the trailing [0] is presumably a document index -- TODO confirm.
model.grad_check(np.arange(1, 4), np.arange(2, 5), [0])

# Same two sequences, each wrapped in a one-element list, plus dv_dummy.
# Fresh arrays are built for this call rather than reusing the ones above.
model.generate_docvecs([np.arange(1, 4)], [np.arange(2, 5)], [0], dv_dummy)

from drnnlm import DRNNLM

# Repeat the same smoke test for DRNNLM with freshly drawn random matrices.
wv_dummy = np.random.randn(10, 50)
dv_dummy = np.random.randn(1, 50) * 1e-2
# Constructor arguments gathered in one place before building the model.
model_kwargs = dict(
    L0=wv_dummy,
    D0=dv_dummy,
    U0=wv_dummy,
    alpha=0.005,
    rseed=10,
    bptt=4
)
model = DRNNLM(**model_kwargs)
# Arrays [1, 2, 3] and [2, 3, 4] with the list [0]
# (presumably inputs, targets and a document index -- TODO confirm).
model.grad_check(np.arange(1, 4), np.arange(2, 5), [0])

# Same sequences as one-element lists; new arrays are created for this call.
model.generate_docvecs([np.arange(1, 4)], [np.arange(2, 5)], [0], dv_dummy)


예제 #4
0
파일: grad_test.py 프로젝트: chagri/docvec
import numpy as np
from simple_drnnlm import SimpleDRNNLM

# Smoke test for SimpleDRNNLM using small random matrices.
# NOTE(review): the names suggest wv_dummy holds word vectors (10x50) and
# dv_dummy a single document vector (1x50) -- confirm in SimpleDRNNLM.
wv_dummy = np.random.randn(10, 50)
dv_dummy = 0.01 * np.random.randn(1, 50)
# Hyper-parameters kept together; the matrices are passed explicitly below.
hyper_params = {"alpha": 0.005, "rseed": 10, "bptt": 4}
model = SimpleDRNNLM(L0=wv_dummy, D0=dv_dummy, U0=wv_dummy, **hyper_params)

# Arrays [1, 2, 3] and [2, 3, 4] plus the list [0]
# (presumably inputs, targets and a document index -- TODO confirm).
model.grad_check(np.arange(1, 4), np.arange(2, 5), [0])

# Same sequences wrapped in one-element lists; new arrays are built for this
# call rather than shared with the grad_check call.
model.generate_docvecs(
    [np.arange(1, 4)],
    [np.arange(2, 5)],
    [0],
    dv_dummy
)

from drnnlm import DRNNLM

# Run the same smoke test against DRNNLM with newly drawn random matrices.
wv_dummy = np.random.randn(10, 50)
dv_dummy = np.random.randn(1, 50) * 1e-2
# Hyper-parameters kept together; the matrices are passed explicitly below.
hyper_params = {"alpha": 0.005, "rseed": 10, "bptt": 4}
model = DRNNLM(L0=wv_dummy, D0=dv_dummy, U0=wv_dummy, **hyper_params)

# Arrays [1, 2, 3] and [2, 3, 4] plus the list [0]
# (presumably inputs, targets and a document index -- TODO confirm).
model.grad_check(np.arange(1, 4), np.arange(2, 5), [0])

# Same sequences wrapped in one-element lists; new arrays are built for this
# call rather than shared with the grad_check call.
model.generate_docvecs(
    [np.arange(1, 4)],
    [np.arange(2, 5)],
    [0],
    dv_dummy
)