Example #1
import argparse
import logging as log
from os import path

# Project-local imports; the module path for `prep` (home of load_data) is an
# assumption based on the other examples in this file.
from lstm import preprocessing as prep
from lstm.model import CharLSTM, train


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-f", "--file", help="the text file to train on", default='anna.txt')
    parser.add_argument("-p", "--prime", help="the prime message to begin with for generation", required=True)
    parser.add_argument("-n", "--count", help="the number of characters of generated text", type=int, required=True)
    parser.add_argument("-s", "--seq_length", help="the sequence length used for training", type=int, default=120)
    args = parser.parse_args()
    file = args.file
    prime = args.prime
    count = args.count
    seq_length = args.seq_length

    log.info("generating text of size {} from file {} with prime '{}' ...".format(count, file, prime))

    text = prep.load_data(path=path.join('data', file))

    network = CharLSTM(text, n_hidden=512, n_layers=3)
    if network.already_trained():
        network.load()
    else:
        train(network, epochs=40, batch_size=128, seq_length=seq_length, print_every=50)

    log.info("generating text now:")
    generated_text = network.generate(count, prime=prime)
    print(generated_text)
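
The same flow can also be driven programmatically. A minimal sketch, assuming the CharLSTM/train API shown in Examples #3 and #4; the sample text, short training settings, output length and prime string here are purely illustrative:

from lstm.model import CharLSTM, train

text = "hello world! shall we begin? let's go."
net = CharLSTM(text, n_hidden=512, n_layers=3)

# Train briefly, then sample 50 characters seeded with a prime string.
train(net, epochs=2, batch_size=2, seq_length=2, print_every=4)
print(net.generate(50, prime="hello"))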
Example #2
from tqdm import tqdm

# `device` is assumed to be defined at module level, e.g.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def train_epoc(trn, model, criterion, optimizer, scheduler):
    """Run one training epoch; returns the summed loss and correct-prediction count."""
    model.train()
    train_loss = 0
    train_acc = 0

    for label, data_x, data_len in tqdm(trn):
        optimizer.zero_grad()
        label = label.to(device)
        data_len = data_len.to(device)
        # Move each input tensor in the batch to the training device.
        # (An earlier variant packed the padded sequences here with
        # rnn_utils.pack_padded_sequence before moving them.)
        for i in range(len(data_x)):
            data_x[i] = data_x[i].to(device)
        output = model(data_x, data_len)
        loss = criterion(output, label)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        train_acc += (output.argmax(1) == label).sum().item()

    # Adjust the learning rate once per epoch.
    scheduler.step()

    return train_loss, train_acc
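
Note that train_epoc returns running sums, not averages, so the caller is expected to normalize them. A minimal sketch of a driver loop, assuming trn is a torch DataLoader and that n_epochs, model, criterion, optimizer and scheduler are defined as above (all of these names are assumptions):

for epoch in range(n_epochs):
    loss_sum, correct = train_epoc(trn, model, criterion, optimizer, scheduler)
    # Assuming the criterion uses the default 'mean' reduction, the loss sum
    # is per batch, so divide by the number of batches; the accuracy sum
    # counts samples, so divide by the dataset size.
    print(f"epoch {epoch}: loss {loss_sum / len(trn):.4f}, "
          f"acc {correct / len(trn.dataset):.4f}")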
Example #3
    def test_should_train_network_on_cpu(self):
        model.on_gpu = False
        net = CharLSTM(sample_text)
        val_loss = train(net,
                         epochs=2,
                         batch_size=2,
                         seq_length=2,
                         print_every=4)
        self.assertGreater(val_loss, 0)
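
One caveat the test above does not handle: model.on_gpu is a module-level flag, so flipping it leaks into any test that runs afterwards. A minimal sketch of a more isolated variant (not from the source):

    def test_should_train_network_on_cpu(self):
        original = model.on_gpu
        model.on_gpu = False
        try:
            net = CharLSTM(sample_text)
            val_loss = train(net, epochs=2, batch_size=2,
                             seq_length=2, print_every=4)
            self.assertGreater(val_loss, 0)
        finally:
            # Restore the flag so later tests see the original value.
            model.on_gpu = original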
Example #4
import logging as log
import os
import unittest

from lstm import model
from lstm.model import CharLSTM, train

sample_text = "hello world! shall we begin? let's go."
trained_model = CharLSTM("hello world! shall we begin? let's go.")
loss = train(trained_model,
             epochs=2,
             batch_size=2,
             seq_length=2,
             print_every=4)


class PreProcessingTest(unittest.TestCase):
    def test_gpu_should_be_available(self):
        self.assertTrue(model.on_gpu)

    def test_should_create_LSTM_by_default(self):
        net = CharLSTM("hello")
        self.assertEqual(0.001, net.lr)
        self.assertEqual(4, len(net.char2int))
        self.assertEqual(net.n_hidden, net.fc.in_features)
        self.assertEqual(4, net.lstm.input_size)
        self.assertEqual(512, net.lstm.hidden_size)
        self.assertEqual(net.n_layers, net.lstm.num_layers)
        self.assertEqual(0.5, net.lstm.dropout)
        log.info("LSTM layers: {}".format(net.lstm))