Пример #1
0
def train(config):
    """Train a character-level ``TextGenerationModel`` on ``config.txt_file``.

    Makes a single pass over the data loader, stopping once step index
    ``config.train_steps`` has been processed (indices 0 through
    ``train_steps`` inclusive). Loss and accuracy are logged to TensorBoard
    every step, progress is printed every ``config.print_every`` steps and
    text is sampled every ``config.sample_every`` steps. The whole model
    object is saved to ``trained_model_part2.pth`` at the end.

    Args:
        config: Namespace-like object providing ``device``, ``txt_file``,
            ``seq_length``, ``batch_size``, ``lstm_num_hidden``,
            ``lstm_num_layers``, ``dropout_keep_prob``, ``learning_rate``,
            ``learning_rate_step``, ``learning_rate_decay``, ``print_every``,
            ``sample_every`` and ``train_steps``.
    """
    # Fixed seed so runs are reproducible.
    seed = 42
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Initialize the device on which to run the model
    device = torch.device(config.device)

    seq_length = config.seq_length
    batch_size = config.batch_size
    lstm_num_hidden = config.lstm_num_hidden
    lstm_num_layers = config.lstm_num_layers
    dropout_keep_prob = config.dropout_keep_prob

    # Initialize the dataset and data loader
    dataset = TextDataset(config.txt_file, seq_length)
    data_loader = DataLoader(dataset, batch_size, num_workers=1)

    vocab_size = dataset.vocab_size

    # Initialize the model that we are going to use
    model = TextGenerationModel(batch_size, seq_length, vocab_size,
                                lstm_num_hidden, lstm_num_layers,
                                dropout_keep_prob, device)
    model.to(device)

    # Setup the loss, the optimizer and the step-wise learning-rate decay
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer,
                                             config.learning_rate_step,
                                             config.learning_rate_decay)

    writer = SummaryWriter()
    try:
        for step, (batch_inputs, batch_targets) in enumerate(data_loader):

            # Only for time measurement of step through network
            t1 = time.time()

            # NOTE(review): sample_from_model may switch the model to eval
            # mode (not visible here); re-enable training mode every step so
            # dropout stays active during optimisation.
            model.train()

            # Inputs stay as class indices for the model's embedding
            # (one-hot encoding was considered and rejected).
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            # Call the module itself rather than .forward() so that any
            # registered hooks run.
            train_output, _ = model(batch_inputs)

            loss = criterion(train_output, batch_targets)
            # dim=1 is the class dimension CrossEntropyLoss consumes here;
            # presumably train_output is (batch, vocab, seq) -- TODO confirm.
            predictions = torch.argmax(train_output, dim=1)
            num_correct = torch.eq(predictions, batch_targets).sum().item()
            accuracy = num_correct / (batch_targets.size(0) *
                                      batch_targets.size(1))

            writer.add_scalar('Loss/train', loss.item(), step)
            writer.add_scalar('Accuracy/train', accuracy, step)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Passing the step index to scheduler.step() is deprecated;
            # advance the schedule once per optimizer step instead.
            lr_scheduler.step()

            # Just for time measurement
            t2 = time.time()
            examples_per_second = config.batch_size / float(t2 - t1)

            if step % config.print_every == 0:
                print(
                    "[{}] Train Step {:04d}/{:04d}, Batch Size = {}, Examples/Sec = {:.2f}, "
                    "Accuracy = {:.2f}, Loss = {:.3f}".format(
                        datetime.now().strftime("%Y-%m-%d %H:%M"), step,
                        int(config.train_steps), config.batch_size,
                        examples_per_second, accuracy, loss.item()))

            if step % config.sample_every == 0:
                # Generate some sentences by sampling from the model
                sample_from_model(config, step, model, dataset)

            if step == config.train_steps:
                # If you receive a PyTorch data-loader error, check this bug report:
                # https://github.com/pytorch/pytorch/pull/9655
                break

        print('Done training.')
        torch.save(model, "trained_model_part2.pth")
    finally:
        # Flush and close the TensorBoard writer even if training fails.
        writer.close()
Пример #2
0
def train(config):
    """Train (optionally resuming) a ``TextGenerationModel`` for several epochs.

    The LSTM state ``(h, c)`` is reset at the start of every epoch and carried
    (detached) across batches within an epoch. Every ``config.sample_every``
    steps a sample is generated via ``predict``, and the model and optimizer
    state dicts are saved to ``intermediate-model-epoch-E-step-S.pth`` and
    ``intermediate-optim-epoch-E-step-S.pth``.

    Args:
        config: Namespace-like object providing ``txt_file``, ``seq_length``,
            ``batch_size``, ``lstm_num_hidden``, ``lstm_num_layers``,
            ``learning_rate``, ``learning_rate_step``,
            ``learning_rate_decay``, ``max_norm``, ``print_every``,
            ``sample_every`` and ``train_steps``. Optional attributes:
            ``epochs`` (default 50), ``resume_model_path`` and
            ``resume_optim_path`` (default: the epoch-30 checkpoint under
            ``grimm-results/``; set ``resume_model_path`` to a falsy value
            to train from scratch).
    """
    # config.device is deliberately ignored; use CUDA whenever available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = TextDataset(filename=config.txt_file,
                          seq_length=config.seq_length)
    data_loader = DataLoader(dataset, config.batch_size, num_workers=1)

    VOCAB_SIZE = dataset.vocab_size
    CHAR2IDX = dataset._char_to_ix
    IDX2CHAR = dataset._ix_to_char

    # Initialize the model that we are going to use
    model = TextGenerationModel(batch_size=config.batch_size,
                                seq_length=config.seq_length,
                                vocabulary_size=VOCAB_SIZE,
                                lstm_num_hidden=config.lstm_num_hidden,
                                lstm_num_layers=config.lstm_num_layers,
                                device=device)
    # Move the model BEFORE building the optimizer and loading its state:
    # Optimizer.load_state_dict casts the state tensors to the parameters'
    # current device, so loading while the parameters are still on the CPU
    # leaves Adam's moment buffers on the CPU after the model moves to CUDA,
    # and optimizer.step() then fails with a device mismatch.
    model = model.to(device)

    # Setup the loss and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    scheduler = scheduler_lib.StepLR(optimizer=optimizer,
                                     step_size=config.learning_rate_step,
                                     gamma=config.learning_rate_decay)

    # Resume from a checkpoint; the defaults are the previously hard-coded
    # epoch-30 checkpoint files.
    resume_model_path = getattr(
        config, 'resume_model_path',
        'grimm-results/intermediate-model-epoch-30-step-0.pth')
    resume_optim_path = getattr(
        config, 'resume_optim_path',
        'grimm-results/intermediate-optim-epoch-30-step-0.pth')
    if resume_model_path:
        model.load_state_dict(
            torch.load(resume_model_path, map_location='cpu'))
        if resume_optim_path:
            optimizer.load_state_dict(
                torch.load(resume_optim_path, map_location='cpu'))

        print("Loaded it!")

    epochs = getattr(config, 'epochs', 50)

    for epoch in range(epochs):
        # Initial LSTM state for the forward pass, reset every epoch
        h, c = model.reset_lstm(config.batch_size)
        h = h.to(device)
        c = c.to(device)

        for step, (batch_inputs, batch_targets) in enumerate(data_loader):

            # Only for time measurement of step through network
            t1 = time.time()

            # predict() may have left the model in eval mode.
            model.train()

            optimizer.zero_grad()

            x = torch.stack(batch_inputs, dim=1).to(device)

            # The carried (h, c) state is sized for config.batch_size, so a
            # short final batch cannot be processed; end the epoch instead.
            if x.size()[0] != config.batch_size:
                print("We're breaking because something is wrong")
                print("Current batch is of size {}".format(x.size()[0]))
                print("Supposed batch size is {}".format(config.batch_size))
                break

            y = torch.stack(batch_targets, dim=1).to(device)

            x = one_hot_encode(x, VOCAB_SIZE)

            output, (h, c) = model(x=x, prev_state=(h, c))

            # CrossEntropyLoss expects the class dimension second.
            loss = criterion(output.transpose(1, 2), y)

            accuracy = calculate_accuracy(output, y)
            # Truncated BPTT: keep the state values but cut the graph.
            h = h.detach()
            c = c.detach()
            loss.backward()
            # Gradient clipping against exploding LSTM gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           max_norm=config.max_norm)
            optimizer.step()
            scheduler.step()

            # Just for time measurement
            t2 = time.time()
            examples_per_second = config.batch_size / float(t2 - t1)

            if step % config.print_every == 0:
                # Print plain numbers rather than tensor reprs (which include
                # device and grad_fn). NOTE(review): float(accuracy) assumes
                # calculate_accuracy returns a scalar -- confirm.
                print(
                    "Epoch {} Train Step {:04d}/{:04d}, Batch Size = {}, "
                    "Examples/Sec = {:.2f}, Accuracy = {:.2f}, "
                    "Loss = {:.3f}".format(epoch, step,
                                           int(config.train_steps),
                                           config.batch_size,
                                           examples_per_second,
                                           float(accuracy), loss.item()))

            if step % config.sample_every == 0:
                FIRST_CHAR = 'I'  # Is randomized within the prediction, actually
                # Generate some sentences by sampling from the model
                predict(device, model, FIRST_CHAR, VOCAB_SIZE, IDX2CHAR,
                        CHAR2IDX)
                path_model = 'intermediate-model-epoch-{}-step-{}.pth'.format(
                    epoch, step)
                path_optimizer = 'intermediate-optim-epoch-{}-step-{}.pth'.format(
                    epoch, step)
                torch.save(model.state_dict(), path_model)
                torch.save(optimizer.state_dict(), path_optimizer)

            if step == config.train_steps:
                # If you receive a PyTorch data-loader error, check this bug report:
                # https://github.com/pytorch/pytorch/pull/9655
                break

    print('Done training.')
Пример #3
0
def train(config):
    """Train a ``TextGenerationModel`` with RMSprop and plot its learning curves.

    The data loader is iterated repeatedly until step ``config.train_steps``
    has been processed. Every ``config.sample_every`` steps a sentence of
    length ``config.gsen_length`` is generated at ``config.temperature`` and
    appended to ``sentences.txt``. Accuracy and loss are collected according
    to ``config.measure_type``:

    * ``0`` -- every step;
    * ``2`` -- averaged over each ``config.print_every`` interval and
      recorded at the print step closing that interval;
    * anything else -- the single-step value at each print step.

    The collected ``(accuracy_list, loss_list)`` tuple is pickled to
    ``loss_and_train.p`` and both curves are shown with matplotlib.
    """
    # Initialize the dataset and data loader
    dataset = TextDataset(config.txt_file, config.seq_length)
    data_loader = DataLoader(dataset, config.batch_size, num_workers=1)

    # The model receives config.device; the training device is read back
    # from the model.
    model = TextGenerationModel(config.batch_size, config.seq_length,
                                dataset.vocab_size, config.lstm_num_hidden,
                                config.lstm_num_layers, config.device)

    device = model.device
    model = model.to(device)

    # Setup the loss and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.RMSprop(model.parameters(),
                                    lr=config.learning_rate)
    print("Len dataset:", len(dataset))
    print("Amount of steps for dataset:", len(dataset) / config.batch_size)

    # Embedding lookup module (the original notes its shape as
    # (vocab_size, lstm_num_hidden)).
    embed = model.embed

    current_step = 0
    not_max = True

    list_train_acc = []
    list_train_loss = []
    acc_average = []
    loss_average = []

    with open("sentences.txt", 'w', encoding='utf-8') as sentence_file:
        while not_max:

            for batch_inputs, batch_targets in data_loader:

                # Only for time measurement of step through network
                t1 = time.time()

                # Embed every time step and stack them so the LSTM receives
                # the whole sequence at once. NOTE(review): batch_inputs
                # looks like a list of seq_length index tensors of shape
                # (batch,) -- confirm against TextDataset.
                all_embed = torch.stack([embed(batch_letter.to(device))
                                         for batch_letter in batch_inputs])
                # The original notes outputs as [seq, batch, vocab_size]
                outputs = model(all_embed)

                # CrossEntropyLoss needs (batch, vocab, seq) predictions and
                # (batch, seq) targets.
                batch_first_output = outputs.transpose(0, 1).transpose(1, 2)
                batch_targets = torch.stack(batch_targets).to(device)
                loss = criterion(batch_first_output, torch.t(batch_targets))

                # Backpropagate; clip_grad_norm_ replaces the deprecated
                # clip_grad_norm.
                model.zero_grad()
                loss.backward()
                loss = loss.item()
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               max_norm=config.max_norm)
                optimizer.step()

                # Accuracy over every (time step, example) position
                number_predictions = torch.argmax(outputs, dim=2)
                result = number_predictions == batch_targets
                accuracy = result.sum().item() / (batch_targets.shape[0] *
                                                  batch_targets.shape[1])

                if config.measure_type == 2:
                    acc_average.append(accuracy)
                    loss_average.append(loss)

                # Just for time measurement
                t2 = time.time()
                examples_per_second = config.batch_size / float(t2 - t1)

                if current_step % config.print_every == 0:

                    if config.measure_type == 2:
                        # Average over the steps actually collected since the
                        # last print; at step 0 that is a single step, not
                        # print_every of them.
                        accuracy = sum(acc_average) / len(acc_average)
                        loss = sum(loss_average) / len(loss_average)
                        acc_average = []
                        loss_average = []

                    # Either the values at this step or the interval average
                    list_train_acc.append(accuracy)
                    list_train_loss.append(loss)

                    print(
                        "[{}] Train Step {:04d}/{:04d}, Batch Size = {}, Examples/Sec = {:.2f}, "
                        "Accuracy = {:.2f}, Loss = {:.3f}".format(
                            datetime.now().strftime("%Y-%m-%d %H:%M"),
                            current_step, config.train_steps,
                            config.batch_size, examples_per_second, accuracy,
                            loss))
                elif config.measure_type == 0:
                    # Track accuracy and loss for every step
                    list_train_acc.append(accuracy)
                    list_train_loss.append(loss)

                if current_step % config.sample_every == 0:
                    # Generate sentence
                    sentence_id = model.generate_sentence(config.gsen_length,
                                                          config.temperature)
                    sentence = dataset.convert_to_string(sentence_id)
                    print(sentence)
                    sentence_file.write(str(current_step) + ": " + sentence +
                                        "\n")

                if current_step == config.train_steps:
                    # If you receive a PyTorch data-loader error, check this bug report:
                    # https://github.com/pytorch/pytorch/pull/9655
                    not_max = False
                    break

                current_step += 1

    # Save the measures; the context manager guarantees the file is closed.
    with open("loss_and_train.p", "wb") as measures_file:
        pickle.dump((list_train_acc, list_train_loss), measures_file)

    print(len(list_train_acc))

    if config.measure_type == 0:
        eval_steps = list(range(config.train_steps + 1))  # Every step Acc
    else:
        # Values exist for every step in [0, train_steps] divisible by
        # print_every. The upper bound must be train_steps + 1; a bound of
        # train_steps + print_every adds a phantom x value whenever
        # train_steps is not a multiple of print_every.
        eval_steps = list(
            range(0, config.train_steps + 1, config.print_every))

    _plot_training_curve(eval_steps, list_train_acc, config.measure_type,
                         "Train accuracy", "Accuracy",
                         "Training accuracy LSTM")
    _plot_training_curve(eval_steps, list_train_loss, config.measure_type,
                         "Train loss", "Loss", "Training loss LSTM")
    print('Done training.')


def _plot_training_curve(eval_steps, values, measure_type, label, ylabel,
                         title):
    """Plot one learning curve against its training steps and show it."""
    if measure_type == 2:
        # values[0] covers only step 0. Each later value averages the interval
        # that ends at its step, so pair values[1:] with eval_steps[1:].
        plt.plot(eval_steps[1:], values[1:], label=label)
    else:
        plt.plot(eval_steps, values, label=label)
    plt.xlabel("Step")
    plt.ylabel(ylabel)
    plt.title(title, fontsize=18, fontweight="bold")
    plt.legend()
    plt.show()